Skip to content

Commit edac77b

Browse files
Support for clustering detections (#818)
* feat: Added pgvector extension * feat: Added features field to the Classification model * changed taxon and detection to autocomplete fields in the ClassificationAdmin model * feat: added similar action to the ClassificationViewset * chore: changed features vector field name to features_2048 * chore: changed features vector field name to features_2048 * feat: read features vector from processing service ClassificationResponse and save it to Classification object * test: added tests for PGVector distance metrics * updated docker-compose.ci.yml to use the same postgres image * updated docker-compose.ci.yml to use the same postgres image as docker-compose.yml * updated docker-compose.ci.yml to use the same postgres image as docker-compose.yml * feat: Added support for clustering detections for source image collections * feat: Allowed triggering collection detections clustering from admin page * fix: show unobserved Taxa in view for now * fix: create & update occurrence determinations after clustering * feat: add unknown species filter to admin * fix: circular import * fix: update migration ordering * Integrated Agglomerative clustering * updated clustering request params * fixed Agglomerative clustering * fix: disable missing clustering algorithms * fix: syntax when creating algorithm entry * feat: command to create clustering job without starting it * feat: increase default batch size * fix: better algorithm name * feat: allow sorting by OOD score * feat: add unknown species and other fields to Taxon serializer * fix: remove missing field * fix: migration conflicts * feat: fields for investigating occurrence classifications in admin * fix: filter by feature extraction algorithm * chore: Used a serializer to handle job params instead of reading them directly from the request objects * set default ood threshold to 0.0 * test: added tests for clustering * chore: migration for new algorithm type * fix: remove cluster action in Event admin until its ready * chore: move 
algorithm selection to dedicated function * fix: update clustering tests and types * chore: remove external network config in processing services * feat: update GitHub workflows to run tests on other branches * fix: hide unobserved taxa by default. todo: add frontend filter to toggle this --------- Co-authored-by: Michael Bunsen <notbot@gmail.com>
1 parent 2b1a470 commit edac77b

24 files changed

+821
-17
lines changed

.github/workflows/test.backend.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ env:
77

88
on:
99
pull_request:
10-
branches: ["master", "main"]
10+
branches: ["main", "deployments/*", "releases/*"]
1111
paths-ignore: ["docs/**", "ui/**"]
1212

1313
push:
14-
branches: ["master", "main"]
14+
branches: ["main", "deployments/*", "releases/*"]
1515
paths-ignore: ["docs/**", "ui/**"]
1616

1717
concurrency:

.github/workflows/test.frontend.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ env:
77

88
on:
99
pull_request:
10-
branches: ["master", "main"]
10+
branches: ["main", "deployments/*", "releases/*"]
1111
paths:
1212
- "!./**"
1313
- "ui/**"
1414

1515
push:
16-
branches: ["master", "main"]
16+
branches: ["main", "deployments/*", "releases/*"]
1717
paths:
1818
- "!./**"
1919
- "ui/**"
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Generated by Django 4.2.10 on 2025-04-24 16:25
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
    """Alter ``Job.job_type_key`` so its choices include ``occurrence_clustering``."""

    dependencies = [("jobs", "0016_job_data_export_job_params_alter_job_job_type_key")]

    operations = [
        migrations.AlterField(
            model_name="job",
            name="job_type_key",
            field=models.CharField(
                verbose_name="Job Type",
                max_length=255,
                default="unknown",
                # Full list of (key, label) pairs accepted for a job type.
                choices=[
                    ("ml", "ML pipeline"),
                    ("populate_captures_collection", "Populate captures collection"),
                    ("data_storage_sync", "Data storage sync"),
                    ("unknown", "Unknown"),
                    ("data_export", "Data Export"),
                    ("occurrence_clustering", "Occurrence Feature Clustering"),
                ],
            ),
        ),
    ]
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Generated by Django 4.2.10 on 2025-04-28 11:06
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
    """Alter ``Job.job_type_key`` so its choices include ``detection_clustering``."""

    dependencies = [("jobs", "0017_alter_job_job_type_key")]

    operations = [
        migrations.AlterField(
            model_name="job",
            name="job_type_key",
            field=models.CharField(
                verbose_name="Job Type",
                max_length=255,
                default="unknown",
                # Full list of (key, label) pairs accepted for a job type.
                choices=[
                    ("ml", "ML pipeline"),
                    ("populate_captures_collection", "Populate captures collection"),
                    ("data_storage_sync", "Data storage sync"),
                    ("unknown", "Unknown"),
                    ("data_export", "Data Export"),
                    ("detection_clustering", "Detection Feature Clustering"),
                ],
            ),
        ),
    ]

ami/jobs/models.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,8 @@ def run(cls, job: "Job"):
400400
total_classifications = 0
401401

402402
config = job.pipeline.get_config(project_id=job.project.pk)
403-
chunk_size = config.get("request_source_image_batch_size", 1)
403+
chunk_size = config.get("request_source_image_batch_size", 2)
404+
# @TODO Ensure only images of the same dimensions are processed in a batch
404405
chunks = [images[i : i + chunk_size] for i in range(0, image_count, chunk_size)] # noqa
405406
request_failed_images = []
406407

@@ -639,6 +640,38 @@ def run(cls, job: "Job"):
639640
job.update_status(JobState.SUCCESS, save=True)
640641

641642

643+
class DetectionClusteringJob(JobType):
    """Job type that clusters the detections of a source image collection.

    The clustering itself is delegated to the collection's
    ``cluster_detections`` method; this class only manages the job's status,
    timestamps and progress stages around that call.
    """

    name = "Detection Feature Clustering"
    key = "detection_clustering"

    @classmethod
    def run(cls, job: "Job"):
        """Run the clustering job for ``job.source_image_collection``.

        Args:
            job: The job to execute. Its status, timestamps and progress
                stages are updated and saved by this method.

        Raises:
            ValueError: If the job has no source image collection. The job has
                already been marked as started and saved at that point.
        """
        # Mark the job as started and register its progress stages exactly
        # once. (Previously the status/timestamp reset and the save were
        # repeated a second time further down, which cost an extra write and
        # overwrote ``started_at``.)
        job.update_status(JobState.STARTED)
        job.started_at = datetime.datetime.now()
        job.finished_at = None
        job.progress.add_stage(name="Collecting Features", key="feature_collection")
        job.progress.add_stage("Clustering", key="clustering")
        job.progress.add_stage("Creating Unknown Taxa", key="create_unknown_taxa")
        job.save()

        if not job.source_image_collection:
            raise ValueError("No source image collection provided")

        job.logger.info(f"Clustering detections for collection {job.source_image_collection}")

        # The job is passed along so the clustering code can access it.
        job.source_image_collection.cluster_detections(job=job)
        job.logger.info(f"Finished clustering detections for collection {job.source_image_collection}")

        # Set the finish time and final status in memory, then persist both
        # with a single save.
        job.finished_at = datetime.datetime.now()
        job.update_status(JobState.SUCCESS, save=False)
        job.save()
673+
674+
642675
class UnknownJobType(JobType):
643676
name = "Unknown"
644677
key = "unknown"
@@ -648,7 +681,14 @@ def run(cls, job: "Job"):
648681
raise ValueError(f"Unknown job type '{job.job_type()}'")
649682

650683

651-
VALID_JOB_TYPES = [MLJob, SourceImageCollectionPopulateJob, DataStorageSyncJob, UnknownJobType, DataExportJob]
684+
# Job type classes known to this module, including DetectionClusteringJob.
VALID_JOB_TYPES = [
    MLJob, SourceImageCollectionPopulateJob, DataStorageSyncJob,
    UnknownJobType, DataExportJob, DetectionClusteringJob,
]
652692

653693

654694
def get_job_type_by_key(key: str) -> type[JobType] | None:

ami/main/admin.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from django.http.request import HttpRequest
77
from django.template.defaultfilters import filesizeformat
88
from django.utils.formats import number_format
9+
from django.utils.html import format_html
910
from guardian.admin import GuardedModelAdmin
1011

1112
import ami.utils
@@ -220,7 +221,6 @@ def update_calculated_fields(self, request: HttpRequest, queryset: QuerySet[Even
220221
self.message_user(request, f"Updated {queryset.count()} events.")
221222

222223
list_filter = ("deployment", "project", "start")
223-
actions = [update_calculated_fields]
224224

225225

226226
@admin.register(SourceImage)
@@ -262,20 +262,27 @@ class ClassificationInline(admin.TabularInline):
262262
model = Classification
263263
extra = 0
264264
fields = (
265+
"view_classification",
265266
"taxon",
266267
"algorithm",
267268
"timestamp",
268269
"terminal",
269270
"created_at",
270271
)
271272
readonly_fields = (
273+
"view_classification",
272274
"taxon",
273275
"algorithm",
274276
"timestamp",
275277
"terminal",
276278
"created_at",
277279
)
278280

281+
@admin.display(description="Classification")
def view_classification(self, obj):
    """Render a link to the admin change page of the given classification.

    The URL is resolved with ``reverse()`` instead of a hard-coded
    ``/admin/...`` path so the link keeps working when the admin site is
    mounted under a different URL prefix.
    """
    # Imported locally so this method does not depend on a module-level import.
    from django.urls import reverse

    # NOTE(review): assumes the default admin site namespace ("admin") — confirm.
    url = reverse("admin:main_classification_change", args=[obj.pk])
    return format_html('<a href="{}">{}</a>', url, obj.pk)
285+
279286
def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
280287
qs = super().get_queryset(request)
281288
return qs.select_related("taxon", "algorithm", "detection")
@@ -285,20 +292,27 @@ class DetectionInline(admin.TabularInline):
285292
model = Detection
286293
extra = 0
287294
fields = (
295+
"view_detection",
288296
"detection_algorithm",
289297
"source_image",
290298
"timestamp",
291299
"created_at",
292300
"occurrence",
293301
)
294302
readonly_fields = (
303+
"view_detection",
295304
"detection_algorithm",
296305
"source_image",
297306
"timestamp",
298307
"created_at",
299308
"occurrence",
300309
)
301310

311+
@admin.display(description="Detection")
def view_detection(self, obj):
    """Render a link to the admin change page of the given detection.

    The URL is resolved with ``reverse()`` instead of a hard-coded
    ``/admin/...`` path so the link keeps working when the admin site is
    mounted under a different URL prefix.
    """
    # Imported locally so this method does not depend on a module-level import.
    from django.urls import reverse

    # NOTE(review): assumes the default admin site namespace ("admin") — confirm.
    url = reverse("admin:main_detection_change", args=[obj.pk])
    return format_html('<a href="{}">{}</a>', url, obj.pk)
315+
302316

303317
@admin.register(Detection)
304318
class DetectionAdmin(admin.ModelAdmin[Detection]):
@@ -461,7 +475,7 @@ class TaxonAdmin(admin.ModelAdmin[Taxon]):
461475
"created_at",
462476
"updated_at",
463477
)
464-
list_filter = ("lists", "rank", TaxonParentFilter)
478+
list_filter = ("unknown_species", "lists", "rank", TaxonParentFilter)
465479
search_fields = ("name",)
466480
autocomplete_fields = (
467481
"parent",
@@ -594,7 +608,48 @@ def populate_collection_async(self, request: HttpRequest, queryset: QuerySet[Sou
594608
f"Populating {len(queued_tasks)} collection(s) background tasks: {queued_tasks}.",
595609
)
596610

597-
actions = [populate_collection, populate_collection_async]
611+
@admin.action(description="Create clustering job (but don't run it)")
def create_clustering_job(self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]) -> None:
    """Create one detection clustering job per selected collection without enqueueing it.

    A confirmation message is shown for every job that is created.
    """
    # Local import, as in the original; presumably avoids a circular import
    # between the admin and jobs modules — TODO confirm.
    from ami.jobs.models import DetectionClusteringJob, Job

    for collection in queryset:
        job = Job.objects.create(
            name=f"Clustering detections for collection {collection.pk}",
            project=collection.project,
            source_image_collection=collection,
            job_type_key=DetectionClusteringJob.key,
            params={
                "ood_threshold": 0.3,
                "algorithm": "agglomerative",
                "algorithm_kwargs": {"distance_threshold": 80},
                "pca": {"n_components": 384},
            },
        )
        self.message_user(request, f"Created clustering job #{job.pk} for collection #{collection.pk}")
630+
631+
@admin.action()
def cluster_detections(self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]) -> None:
    """Create and enqueue one detection clustering job per selected collection.

    The jobs are only enqueued here, so the confirmation message reports how
    many jobs were queued rather than claiming the clustering has finished.
    """
    # Local import, as in the original; presumably avoids a circular import
    # between the admin and jobs modules — TODO confirm. Hoisted out of the
    # loop so it is executed once per action.
    from ami.jobs.models import DetectionClusteringJob, Job

    queued_job_ids = []
    for collection in queryset:
        job = Job.objects.create(
            name=f"Clustering detections for collection {collection.pk}",
            project=collection.project,
            source_image_collection=collection,
            job_type_key=DetectionClusteringJob.key,
            params={
                "ood_threshold": 0.3,
                "algorithm": "agglomerative",
                "algorithm_kwargs": {"distance_threshold": 80},
                "pca": {"n_components": 384},
            },
        )
        job.enqueue()
        queued_job_ids.append(job.pk)

    # Count the jobs actually created instead of issuing another COUNT query.
    self.message_user(
        request,
        f"Queued clustering jobs for {len(queued_job_ids)} collection(s): {queued_job_ids}.",
    )
651+
652+
actions = [populate_collection, populate_collection_async, cluster_detections, create_clustering_job]
598653

599654
# Hide images many-to-many field from form. This would list all source images in the database.
600655
exclude = ("images",)

ami/main/api/serializers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ class Meta:
518518
"last_detected",
519519
"best_determination_score",
520520
"cover_image_url",
521+
"unknown_species",
521522
"created_at",
522523
"updated_at",
523524
]
@@ -740,6 +741,8 @@ class Meta:
740741
"fieldguide_id",
741742
"cover_image_url",
742743
"cover_image_credit",
744+
"unknown_species",
745+
"last_detected", # @TODO this has performance impact, review
743746
]
744747

745748

@@ -1548,3 +1551,11 @@ class Meta:
15481551
"total_size",
15491552
"last_checked",
15501553
]
1554+
1555+
1556+
class ClusterDetectionsSerializer(serializers.Serializer):
    """Validate the optional parameters of a detection clustering request.

    Every field may be omitted; all but ``feature_extraction_algorithm``
    declare a default value.
    """

    ood_threshold = serializers.FloatField(default=0.0, required=False)
    # No default declared: may be omitted or explicitly null.
    feature_extraction_algorithm = serializers.CharField(allow_null=True, required=False)
    algorithm = serializers.CharField(default="agglomerative", required=False)
    algorithm_kwargs = serializers.DictField(default={"distance_threshold": 0.5}, required=False)
    pca = serializers.DictField(default={"n_components": 384}, required=False)

ami/main/api/views.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
)
4343
from ami.base.serializers import FilterParamsSerializer, SingleParamSerializer
4444
from ami.base.views import ProjectMixin
45+
from ami.jobs.models import DetectionClusteringJob, Job
46+
from ami.main.api.serializers import ClusterDetectionsSerializer
4547
from ami.utils.requests import get_active_classification_threshold, project_id_doc_param
4648
from ami.utils.storages import ConnectionTestResult
4749

@@ -744,6 +746,27 @@ def remove(self, request, pk=None):
744746
}
745747
)
746748

749+
@action(detail=True, methods=["post"], name="cluster detections")
def cluster_detections(self, request, pk=None):
    """
    Queue a background job that clusters the detections of this collection.

    The request body is validated with ``ClusterDetectionsSerializer`` and the
    validated values are stored as the job's params. The response contains the
    id of the new job and the id of the collection's project.
    """
    source_collection: SourceImageCollection = self.get_object()

    # Reject invalid clustering parameters before any job is created.
    param_serializer = ClusterDetectionsSerializer(data=request.data)
    param_serializer.is_valid(raise_exception=True)

    clustering_job = Job.objects.create(
        job_type_key=DetectionClusteringJob.key,
        source_image_collection=source_collection,
        project=source_collection.project,
        name=f"Clustering detections for collection {source_collection.pk}",
        params=param_serializer.validated_data,
    )
    clustering_job.enqueue()
    logger.info(f"Triggered clustering job for collection {source_collection.pk}")

    payload = {"job_id": clustering_job.pk, "project_id": source_collection.project.pk}
    return Response(payload)
769+
747770
@extend_schema(parameters=[project_id_doc_param])
748771
def list(self, request, *args, **kwargs):
749772
return super().list(request, *args, **kwargs)
@@ -1273,8 +1296,7 @@ def get_queryset(self) -> QuerySet:
12731296
project = self.get_active_project()
12741297

12751298
if project:
1276-
# Allow showing detail views for unobserved taxa
1277-
include_unobserved = True
1299+
include_unobserved = True # Show detail views for unobserved taxa instead of 404
12781300
if self.action == "list":
12791301
include_unobserved = self.request.query_params.get("include_unobserved", False)
12801302
qs = self.get_taxa_observed(qs, project, include_unobserved=include_unobserved)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Generated by Django 4.2.10 on 2025-04-28 11:11
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
    """Add the boolean ``unknown_species`` field (default ``False``) to ``Taxon``."""

    dependencies = [("main", "0062_classification_ood_score_and_more")]

    operations = [
        migrations.AddField(
            model_name="taxon",
            name="unknown_species",
            field=models.BooleanField(
                help_text="Is this a clustering-generated taxon",
                default=False,
            ),
        ),
    ]

ami/main/migrations/0060_taxon_cover_image_credit_taxon_cover_image_url_and_more.py renamed to ami/main/migrations/0064_taxon_cover_image_credit_taxon_cover_image_url_and_more.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
class Migration(migrations.Migration):
77
dependencies = [
8-
("main", "0059_alter_project_options"),
8+
("main", "0063_taxon_unknown_species"),
99
]
1010

1111
operations = [

0 commit comments

Comments
 (0)