43
43
from ami .base .serializers import FilterParamsSerializer , SingleParamSerializer
44
44
from ami .base .views import ProjectMixin
45
45
from ami .jobs .models import DetectionClusteringJob , Job
46
- from ami .main .api .serializers import ClusterDetectionsSerializer
46
+ from ami .main .api .serializers import ClusterDetectionsSerializer , TagSerializer
47
47
from ami .utils .requests import get_active_classification_threshold , project_id_doc_param
48
48
from ami .utils .storages import ConnectionTestResult
49
49
63
63
SourceImage ,
64
64
SourceImageCollection ,
65
65
SourceImageUpload ,
66
+ Tag ,
66
67
TaxaList ,
67
68
Taxon ,
68
69
User ,
@@ -1187,6 +1188,29 @@ def filter_queryset(self, request, queryset, view):
1187
1188
return queryset
1188
1189
1189
1190
1191
+ class TaxonTagFilter (filters .BaseFilterBackend ):
1192
+ """FilterBackend that allows OR-based filtering of taxa by tag ID."""
1193
+
1194
+ def filter_queryset (self , request , queryset , view ):
1195
+ tag_ids = request .query_params .getlist ("tag_id" )
1196
+ if tag_ids :
1197
+ queryset = queryset .filter (tags__id__in = tag_ids ).distinct ()
1198
+ return queryset
1199
+
1200
+
1201
+ class TagInverseFilter (filters .BaseFilterBackend ):
1202
+ """
1203
+ Exclude taxa that have any of the specified tag IDs using `not_tag_id`.
1204
+ Example: /api/v2/taxa/?not_tag_id=1¬_tag_id=2
1205
+ """
1206
+
1207
+ def filter_queryset (self , request , queryset , view ):
1208
+ not_tag_ids = request .query_params .getlist ("not_tag_id" )
1209
+ if not_tag_ids :
1210
+ queryset = queryset .exclude (tags__id__in = not_tag_ids )
1211
+ return queryset .distinct ()
1212
+
1213
+
1190
1214
class TaxonViewSet (DefaultViewSet , ProjectMixin ):
1191
1215
"""
1192
1216
API endpoint that allows taxa to be viewed or edited.
@@ -1198,6 +1222,8 @@ class TaxonViewSet(DefaultViewSet, ProjectMixin):
1198
1222
CustomTaxonFilter ,
1199
1223
TaxonCollectionFilter ,
1200
1224
TaxonTaxaListFilter ,
1225
+ TaxonTagFilter ,
1226
+ TagInverseFilter ,
1201
1227
]
1202
1228
filterset_fields = [
1203
1229
"name" ,
@@ -1323,9 +1349,11 @@ def get_queryset(self) -> QuerySet:
1323
1349
"""
1324
1350
qs = super ().get_queryset ()
1325
1351
project = self .get_active_project ()
1352
+ qs = self .attach_tags_by_project (qs , project )
1326
1353
1327
1354
if project :
1328
1355
include_unobserved = True # Show detail views for unobserved taxa instead of 404
1356
+ # @TODO move to a QuerySet manager
1329
1357
qs = qs .annotate (
1330
1358
best_detection_image_path = models .Subquery (
1331
1359
Occurrence .objects .filter (
@@ -1416,6 +1444,43 @@ def get_taxa_observed(self, qs: QuerySet, project: Project, include_unobserved=F
1416
1444
)
1417
1445
return qs
1418
1446
1447
+ def attach_tags_by_project (self , qs : QuerySet , project : Project ) -> QuerySet :
1448
+ """
1449
+ Prefetch and override the `.tags` attribute on each Taxon
1450
+ with only the tags belonging to the given project.
1451
+ """
1452
+ # Include all tags if no project is passed
1453
+ if project is None :
1454
+ tag_qs = Tag .objects .all ()
1455
+ else :
1456
+ # Prefetch only the tags that belong to the project or are global
1457
+ tag_qs = Tag .objects .filter (models .Q (project = project ) | models .Q (project__isnull = True ))
1458
+
1459
+ tag_prefetch = Prefetch ("tags" , queryset = tag_qs , to_attr = "prefetched_tags" )
1460
+
1461
+ return qs .prefetch_related (tag_prefetch )
1462
+
1463
+ @action (detail = True , methods = ["post" ])
1464
+ def assign_tags (self , request , pk = None ):
1465
+ """
1466
+ Assign tags to a taxon
1467
+ """
1468
+ taxon = self .get_object ()
1469
+ tag_ids = request .data .get ("tag_ids" )
1470
+ logger .info (f"Tag IDs: { tag_ids } " )
1471
+ if not isinstance (tag_ids , list ):
1472
+ return Response ({"detail" : "tag_ids must be a list of IDs." }, status = status .HTTP_400_BAD_REQUEST )
1473
+
1474
+ tags = Tag .objects .filter (id__in = tag_ids )
1475
+ logger .info (f"Tags: { tags } , len: { len (tags )} " )
1476
+ taxon .tags .set (tags ) # replaces all tags for this taxon
1477
+ taxon .save ()
1478
+ logger .info (f"Tags after assingment : { len (taxon .tags .all ())} " )
1479
+ return Response (
1480
+ {"taxon_id" : taxon .id , "assigned_tag_ids" : [tag .pk for tag in tags ]},
1481
+ status = status .HTTP_200_OK ,
1482
+ )
1483
+
1419
1484
@extend_schema (parameters = [project_id_doc_param ])
1420
1485
def list (self , request , * args , ** kwargs ):
1421
1486
return super ().list (request , * args , ** kwargs )
@@ -1434,6 +1499,20 @@ def get_queryset(self):
1434
1499
serializer_class = TaxaListSerializer
1435
1500
1436
1501
1502
+ class TagViewSet (DefaultViewSet , ProjectMixin ):
1503
+ queryset = Tag .objects .all ()
1504
+ serializer_class = TagSerializer
1505
+ filterset_fields = ["taxa" ]
1506
+
1507
+ def get_queryset (self ):
1508
+ qs = super ().get_queryset ()
1509
+ project = self .get_active_project ()
1510
+ if project :
1511
+ # Filter by project, but also include global tags
1512
+ return qs .filter (models .Q (project = project ) | models .Q (project__isnull = True ))
1513
+ return qs
1514
+
1515
+
1437
1516
class ClassificationViewSet (DefaultViewSet , ProjectMixin ):
1438
1517
"""
1439
1518
API endpoint for viewing and adding classification results from a model.
0 commit comments