Skip to content

Commit 2b27451

Browse files
committed
Merge branch 'deployments/ood.antenna.insectai.org' of github.com:RolnickLab/antenna into feat/better-cluster-data
2 parents 27f6e50 + 5302486 commit 2b27451

File tree

21 files changed

+512
-46
lines changed

21 files changed

+512
-46
lines changed

ami/main/admin.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
Site,
2828
SourceImage,
2929
SourceImageCollection,
30+
Tag,
3031
TaxaList,
3132
Taxon,
3233
)
@@ -471,6 +472,7 @@ class TaxonAdmin(admin.ModelAdmin[Taxon]):
471472
"rank",
472473
"parent",
473474
"parent_names",
475+
"tag_list",
474476
"list_names",
475477
"created_at",
476478
"updated_at",
@@ -491,10 +493,10 @@ def get_queryset(self, request):
491493

492494
return qs.annotate(occurrence_count=models.Count("occurrences")).order_by("-occurrence_count")
493495

494-
@admin.display(
495-
description="Occurrences",
496-
ordering="occurrence_count",
497-
)
496+
@admin.display(description="Tags")
497+
def tag_list(self, obj) -> str:
498+
return ", ".join([tag.name for tag in obj.tags.all()])
499+
498500
def occurrence_count(self, obj) -> int:
499501
return obj.occurrence_count
500502

@@ -653,3 +655,10 @@ def cluster_detections(self, request: HttpRequest, queryset: QuerySet[SourceImag
653655

654656
# Hide images many-to-many field from form. This would list all source images in the database.
655657
exclude = ("images",)
658+
659+
660+
@admin.register(Tag)
661+
class TagAdmin(admin.ModelAdmin):
662+
list_display = ("id", "name", "project")
663+
list_filter = ("project",)
664+
search_fields = ("name",)

ami/main/api/serializers.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from ami.base.fields import DateStringField
1010
from ami.base.serializers import DefaultSerializer, MinimalNestedModelSerializer, get_current_user, reverse_with_params
1111
from ami.jobs.models import Job
12-
from ami.main.models import create_source_image_from_upload
12+
from ami.main.models import Tag, create_source_image_from_upload
1313
from ami.ml.models import Algorithm
1414
from ami.ml.serializers import AlgorithmSerializer
1515
from ami.users.models import User
@@ -500,12 +500,33 @@ class Meta:
500500
]
501501

502502

503+
class TagSerializer(DefaultSerializer):
504+
project = ProjectNestedSerializer(read_only=True)
505+
project_id = serializers.PrimaryKeyRelatedField(queryset=Project.objects.all(), source="project", write_only=True)
506+
taxa_ids = serializers.PrimaryKeyRelatedField(
507+
queryset=Taxon.objects.all(), many=True, source="taxa", write_only=True, required=False
508+
)
509+
taxa = serializers.SerializerMethodField()
510+
511+
class Meta:
512+
model = Tag
513+
fields = ["id", "name", "project", "project_id", "taxa_ids", "taxa"]
514+
515+
def get_taxa(self, obj):
516+
return [{"id": taxon.id, "name": taxon.name} for taxon in obj.taxa.all()]
517+
518+
503519
class TaxonListSerializer(DefaultSerializer):
504520
# latest_detection = DetectionNestedSerializer(read_only=True)
505521
occurrences = serializers.SerializerMethodField()
506522
parents = TaxonParentSerializer(many=True, read_only=True, source="parents_json")
507523
parent_id = serializers.PrimaryKeyRelatedField(queryset=Taxon.objects.all(), source="parent")
508524
cover_image_url = serializers.SerializerMethodField()
525+
tags = serializers.SerializerMethodField()
526+
527+
def get_tags(self, obj):
528+
tag_list = getattr(obj, "prefetched_tags", [])
529+
return TagSerializer(tag_list, many=True, context=self.context).data
509530

510531
class Meta:
511532
model = Taxon
@@ -518,6 +539,7 @@ class Meta:
518539
"details",
519540
"occurrences_count",
520541
"occurrences",
542+
"tags",
521543
"last_detected",
522544
"best_determination_score",
523545
"cover_image_url",
@@ -737,6 +759,12 @@ class TaxonSerializer(DefaultSerializer):
737759
parent_id = serializers.PrimaryKeyRelatedField(queryset=Taxon.objects.all(), source="parent", write_only=True)
738760
parents = TaxonParentSerializer(many=True, read_only=True, source="parents_json")
739761
cover_image_url = serializers.SerializerMethodField()
762+
tags = serializers.SerializerMethodField()
763+
764+
def get_tags(self, obj):
765+
# Use prefetched tags
766+
tag_list = getattr(obj, "prefetched_tags", [])
767+
return TagSerializer(tag_list, many=True, context=self.context).data
740768

741769
class Meta:
742770
model = Taxon
@@ -752,12 +780,12 @@ class Meta:
752780
"events_count",
753781
"occurrences",
754782
"gbif_taxon_key",
783+
"tags",
755784
"last_detected",
756785
"fieldguide_id",
757786
"cover_image_url",
758787
"cover_image_credit",
759788
"unknown_species",
760-
"last_detected", # @TODO this has performance impact, review
761789
]
762790

763791
def get_cover_image_url(self, obj):

ami/main/api/views.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from ami.base.serializers import FilterParamsSerializer, SingleParamSerializer
4444
from ami.base.views import ProjectMixin
4545
from ami.jobs.models import DetectionClusteringJob, Job
46-
from ami.main.api.serializers import ClusterDetectionsSerializer
46+
from ami.main.api.serializers import ClusterDetectionsSerializer, TagSerializer
4747
from ami.utils.requests import get_active_classification_threshold, project_id_doc_param
4848
from ami.utils.storages import ConnectionTestResult
4949

@@ -63,6 +63,7 @@
6363
SourceImage,
6464
SourceImageCollection,
6565
SourceImageUpload,
66+
Tag,
6667
TaxaList,
6768
Taxon,
6869
User,
@@ -1187,6 +1188,29 @@ def filter_queryset(self, request, queryset, view):
11871188
return queryset
11881189

11891190

1191+
class TaxonTagFilter(filters.BaseFilterBackend):
1192+
"""FilterBackend that allows OR-based filtering of taxa by tag ID."""
1193+
1194+
def filter_queryset(self, request, queryset, view):
1195+
tag_ids = request.query_params.getlist("tag_id")
1196+
if tag_ids:
1197+
queryset = queryset.filter(tags__id__in=tag_ids).distinct()
1198+
return queryset
1199+
1200+
1201+
class TagInverseFilter(filters.BaseFilterBackend):
1202+
"""
1203+
Exclude taxa that have any of the specified tag IDs using `not_tag_id`.
1204+
Example: /api/v2/taxa/?not_tag_id=1&not_tag_id=2
1205+
"""
1206+
1207+
def filter_queryset(self, request, queryset, view):
1208+
not_tag_ids = request.query_params.getlist("not_tag_id")
1209+
if not_tag_ids:
1210+
queryset = queryset.exclude(tags__id__in=not_tag_ids)
1211+
return queryset.distinct()
1212+
1213+
11901214
class TaxonViewSet(DefaultViewSet, ProjectMixin):
11911215
"""
11921216
API endpoint that allows taxa to be viewed or edited.
@@ -1198,6 +1222,8 @@ class TaxonViewSet(DefaultViewSet, ProjectMixin):
11981222
CustomTaxonFilter,
11991223
TaxonCollectionFilter,
12001224
TaxonTaxaListFilter,
1225+
TaxonTagFilter,
1226+
TagInverseFilter,
12011227
]
12021228
filterset_fields = [
12031229
"name",
@@ -1323,9 +1349,11 @@ def get_queryset(self) -> QuerySet:
13231349
"""
13241350
qs = super().get_queryset()
13251351
project = self.get_active_project()
1352+
qs = self.attach_tags_by_project(qs, project)
13261353

13271354
if project:
13281355
include_unobserved = True # Show detail views for unobserved taxa instead of 404
1356+
# @TODO move to a QuerySet manager
13291357
qs = qs.annotate(
13301358
best_detection_image_path=models.Subquery(
13311359
Occurrence.objects.filter(
@@ -1416,6 +1444,43 @@ def get_taxa_observed(self, qs: QuerySet, project: Project, include_unobserved=F
14161444
)
14171445
return qs
14181446

1447+
def attach_tags_by_project(self, qs: QuerySet, project: Project) -> QuerySet:
1448+
"""
1449+
Prefetch and override the `.tags` attribute on each Taxon
1450+
with only the tags belonging to the given project.
1451+
"""
1452+
# Include all tags if no project is passed
1453+
if project is None:
1454+
tag_qs = Tag.objects.all()
1455+
else:
1456+
# Prefetch only the tags that belong to the project or are global
1457+
tag_qs = Tag.objects.filter(models.Q(project=project) | models.Q(project__isnull=True))
1458+
1459+
tag_prefetch = Prefetch("tags", queryset=tag_qs, to_attr="prefetched_tags")
1460+
1461+
return qs.prefetch_related(tag_prefetch)
1462+
1463+
@action(detail=True, methods=["post"])
1464+
def assign_tags(self, request, pk=None):
1465+
"""
1466+
Assign tags to a taxon
1467+
"""
1468+
taxon = self.get_object()
1469+
tag_ids = request.data.get("tag_ids")
1470+
logger.info(f"Tag IDs: {tag_ids}")
1471+
if not isinstance(tag_ids, list):
1472+
return Response({"detail": "tag_ids must be a list of IDs."}, status=status.HTTP_400_BAD_REQUEST)
1473+
1474+
tags = Tag.objects.filter(id__in=tag_ids)
1475+
logger.info(f"Tags: {tags}, len: {len(tags)}")
1476+
taxon.tags.set(tags) # replaces all tags for this taxon
1477+
taxon.save()
1478+
logger.info(f"Tags after assingment : {len(taxon.tags.all())}")
1479+
return Response(
1480+
{"taxon_id": taxon.id, "assigned_tag_ids": [tag.pk for tag in tags]},
1481+
status=status.HTTP_200_OK,
1482+
)
1483+
14191484
@extend_schema(parameters=[project_id_doc_param])
14201485
def list(self, request, *args, **kwargs):
14211486
return super().list(request, *args, **kwargs)
@@ -1434,6 +1499,20 @@ def get_queryset(self):
14341499
serializer_class = TaxaListSerializer
14351500

14361501

1502+
class TagViewSet(DefaultViewSet, ProjectMixin):
1503+
queryset = Tag.objects.all()
1504+
serializer_class = TagSerializer
1505+
filterset_fields = ["taxa"]
1506+
1507+
def get_queryset(self):
1508+
qs = super().get_queryset()
1509+
project = self.get_active_project()
1510+
if project:
1511+
# Filter by project, but also include global tags
1512+
return qs.filter(models.Q(project=project) | models.Q(project__isnull=True))
1513+
return qs
1514+
1515+
14371516
class ClassificationViewSet(DefaultViewSet, ProjectMixin):
14381517
"""
14391518
API endpoint for viewing and adding classification results from a model.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Generated by Django 4.2.10 on 2025-05-15 21:23
2+
3+
from django.db import migrations, models
4+
import django.db.models.deletion
5+
6+
7+
def add_inital_tags(apps, schema_editor):
8+
"""
9+
Add initial tags to the database.
10+
"""
11+
Tag = apps.get_model("main", "Tag")
12+
13+
# Make tags available for all projects
14+
project = None
15+
16+
# Create initial tags
17+
tags = [
18+
"most wanted",
19+
"also wanted",
20+
"reviewed",
21+
"collected",
22+
]
23+
24+
for tag in tags:
25+
Tag.objects.get_or_create(name=tag, project=project)
26+
27+
28+
class Migration(migrations.Migration):
29+
dependencies = [
30+
("main", "0066_populate_cached_occurence_fields"),
31+
]
32+
33+
operations = [
34+
migrations.CreateModel(
35+
name="Tag",
36+
fields=[
37+
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
38+
("created_at", models.DateTimeField(auto_now_add=True)),
39+
("updated_at", models.DateTimeField(auto_now=True)),
40+
("name", models.CharField(max_length=255)),
41+
(
42+
"project",
43+
models.ForeignKey(
44+
blank=True,
45+
null=True,
46+
on_delete=django.db.models.deletion.CASCADE,
47+
related_name="tags",
48+
to="main.project",
49+
),
50+
),
51+
],
52+
options={
53+
"unique_together": {("name", "project")},
54+
},
55+
),
56+
migrations.AddField(
57+
model_name="taxon",
58+
name="tags",
59+
field=models.ManyToManyField(blank=True, related_name="taxa", to="main.tag"),
60+
),
61+
migrations.RunPython(add_inital_tags, migrations.RunPython.noop),
62+
]

ami/main/models.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ class Project(BaseModel):
142142
devices: models.QuerySet["Device"]
143143
sites: models.QuerySet["Site"]
144144
jobs: models.QuerySet["Job"]
145+
tags: models.QuerySet["Tag"]
145146

146147
objects = ProjectManager()
147148

@@ -2906,6 +2907,7 @@ class Taxon(BaseModel):
29062907
authorship_date = models.DateField(null=True, blank=True, help_text="The date the taxon was described.")
29072908
ordering = models.IntegerField(null=True, blank=True)
29082909
sort_phylogeny = models.BigIntegerField(blank=True, null=True)
2910+
tags = models.ManyToManyField("Tag", related_name="taxa", blank=True)
29092911
unknown_species = models.BooleanField(default=False, help_text="Is this a clustering-generated taxon")
29102912
objects: TaxonManager = TaxonManager()
29112913

@@ -3166,6 +3168,19 @@ class Meta:
31663168
verbose_name_plural = "Taxa Lists"
31673169

31683170

3171+
@final
3172+
class Tag(BaseModel):
3173+
"""A tag for taxa"""
3174+
3175+
name = models.CharField(max_length=255)
3176+
project = models.ForeignKey(Project, on_delete=models.CASCADE, related_name="tags", null=True, blank=True)
3177+
3178+
taxa: models.QuerySet[Taxon]
3179+
3180+
class Meta:
3181+
unique_together = ("name", "project")
3182+
3183+
31693184
@final
31703185
class BlogPost(BaseModel):
31713186
"""

config/api_router.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
router.register(r"occurrences", views.OccurrenceViewSet)
3131
router.register(r"taxa/lists", views.TaxaListViewSet)
3232
router.register(r"taxa", views.TaxonViewSet)
33+
router.register(r"tags", views.TagViewSet)
3334
router.register(r"ml/algorithms", ml_views.AlgorithmViewSet)
3435
router.register(r"ml/labels", ml_views.AlgorithmCategoryMapViewSet)
3536
router.register(r"ml/pipelines", ml_views.PipelineViewSet)

ui/src/components/filtering/filter-control.tsx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import { ScoreFilter } from './filters/score-filter'
1111
import { SessionFilter } from './filters/session-filter'
1212
import { StationFilter } from './filters/station-filter'
1313
import { StatusFilter } from './filters/status-filter'
14+
import { TagFilter } from './filters/tag-filter'
1415
import { TaxaListFilter } from './filters/taxa-list-filter'
1516
import { TaxonFilter } from './filters/taxon-filter'
1617
import { TypeFilter } from './filters/type-filter'
@@ -36,6 +37,8 @@ const ComponentMap: {
3637
source_image_collection: CollectionFilter,
3738
source_image_single: ImageFilter,
3839
status: StatusFilter,
40+
tag_id: TagFilter,
41+
not_tag_id: TagFilter,
3942
taxon: TaxonFilter,
4043
taxa_list_id: TaxaListFilter,
4144
verified_by_me: VerifiedByFilter,
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import { Select } from 'nova-ui-kit'
2+
import { FilterProps } from './types'
3+
4+
export const TagFilter = ({ data = [], value, onAdd }: FilterProps) => {
5+
const tags = data as { id: number; name: string }[]
6+
7+
return (
8+
<Select.Root value={value ?? ''} onValueChange={onAdd}>
9+
<Select.Trigger>
10+
<Select.Value placeholder="Select a value" />
11+
</Select.Trigger>
12+
<Select.Content className="max-h-72">
13+
{tags.map((option) => (
14+
<Select.Item key={option.id} value={`${option.id}`}>
15+
{option.name}
16+
</Select.Item>
17+
))}
18+
</Select.Content>
19+
</Select.Root>
20+
)
21+
}

0 commit comments

Comments
 (0)