Skip to content

Commit 67430bb

Browse files
feat: add project default taxa filter to OccurrenceViewSet, TaxonViewSet, SourceImageViewSet, and SummaryView
1 parent 2312da0 commit 67430bb

File tree

2 files changed

+150
-14
lines changed

2 files changed

+150
-14
lines changed

ami/main/api/views.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -494,9 +494,9 @@ def get_queryset(self) -> QuerySet:
494494

495495
classification_threshold = get_default_classification_threshold(project, self.request)
496496
queryset = queryset.with_occurrences_count( # type: ignore
497-
classification_threshold=classification_threshold
497+
classification_threshold=classification_threshold, project=project
498498
).with_taxa_count( # type: ignore
499-
classification_threshold=classification_threshold
499+
classification_threshold=classification_threshold, project=project
500500
)
501501

502502
queryset.select_related(
@@ -1109,6 +1109,7 @@ def get_queryset(self) -> QuerySet["Occurrence"]:
11091109
qs = qs.with_detections_count().with_timestamps() # type: ignore
11101110
qs = qs.with_identifications() # type: ignore
11111111
qs = qs.filter_by_score_threshold(project, self.request) # type: ignore
1112+
qs = qs.filter_by_project_default_taxa(project, self.request) # type: ignore
11121113
if self.action != "list":
11131114
qs = qs.prefetch_related(
11141115
Prefetch(
@@ -1340,6 +1341,8 @@ def get_queryset(self) -> QuerySet:
13401341
qs = self.attach_tags_by_project(qs, project)
13411342

13421343
if project:
1344+
# Filter by project default taxa
1345+
qs = qs.filter_by_project_default_taxa(project, self.request) # type: ignore
13431346
# Allow showing detail views for unobserved taxa
13441347
include_unobserved = True
13451348
if self.action == "list":
@@ -1349,12 +1352,12 @@ def get_queryset(self) -> QuerySet:
13491352
qs = qs.prefetch_related(
13501353
Prefetch(
13511354
"occurrences",
1352-
queryset=Occurrence.objects.filter_by_score_threshold(project, self.request).filter(
1353-
self.get_occurrence_filters(project)
1354-
)[:1],
1355+
queryset=Occurrence.objects.filter_by_score_threshold( # type: ignore
1356+
project, self.request
1357+
).filter(self.get_occurrence_filters(project))[:1],
13551358
to_attr="example_occurrences",
13561359
)
1357-
) # type: ignore
1360+
)
13581361
else:
13591362
# Add empty occurrences list to make the response consistent
13601363
qs = qs.annotate(example_occurrences=models.Value([], output_field=models.JSONField()))
@@ -1551,10 +1554,12 @@ def get(self, request):
15511554
"captures_count": SourceImage.objects.filter(deployment__project=project).count(),
15521555
# "detections_count": Detection.objects.filter(occurrence__project=project).count(),
15531556
"occurrences_count": Occurrence.objects.filter_by_score_threshold(project, self.request)
1557+
.filter_by_project_default_taxa(project, self.request)
15541558
.valid()
15551559
.filter(project=project)
15561560
.count(), # type: ignore
15571561
"taxa_count": Occurrence.objects.filter_by_score_threshold(project, self.request)
1562+
.filter_by_project_default_taxa(project, self.request)
15581563
.unique_taxa(project=project)
15591564
.count(), # type: ignore
15601565
}

ami/main/models.py

Lines changed: 139 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -874,10 +874,57 @@ def save(self, update_calculated_fields=True, regroup_async=True, *args, **kwarg
874874
# ami.tasks.model_task.delay("Project", self.project.pk, "update_children_project")
875875

876876

877+
class EventQuerySet(models.QuerySet):
878+
def with_taxa_count(self, project: Project | None = None, request: Request | None = None):
879+
"""
880+
Annotate each event with the number of distinct taxa observed,
881+
filtered by classification threshold and the project's default
882+
include/exclude taxa settings.
883+
"""
884+
if project is None:
885+
return self
886+
887+
classification_threshold = get_default_classification_threshold(project, request)
888+
889+
# Start with a base filter for classification score
890+
filter_q = models.Q(
891+
occurrences__determination_score__gte=classification_threshold,
892+
)
893+
894+
# Apply include/exclude taxa from project defaults
895+
include_taxa = project.default_filters_include_taxa.all()
896+
exclude_taxa = project.default_filters_exclude_taxa.all()
897+
898+
if include_taxa.exists():
899+
include_filter = models.Q(occurrences__determination__in=include_taxa)
900+
for taxon in include_taxa:
901+
include_filter |= models.Q(occurrences__determination__parents_json__contains=[{"id": taxon.pk}])
902+
filter_q &= include_filter
903+
904+
if exclude_taxa.exists():
905+
exclude_filter = models.Q(occurrences__determination__in=exclude_taxa)
906+
for taxon in exclude_taxa:
907+
exclude_filter |= models.Q(occurrences__determination__parents_json__contains=[{"id": taxon.pk}])
908+
filter_q &= ~exclude_filter
909+
910+
return self.annotate(
911+
taxa_count=models.Count(
912+
"occurrences__determination",
913+
distinct=True,
914+
filter=filter_q,
915+
)
916+
)
917+
918+
919+
class EventManager(models.Manager.from_queryset(EventQuerySet)):
920+
pass
921+
922+
877923
@final
878924
class Event(BaseModel):
879925
"""A monitoring session"""
880926

927+
objects: EventManager = EventManager()
881928
group_by = models.CharField(
882929
max_length=255,
883930
db_index=True,
@@ -1499,24 +1546,54 @@ def delete_source_image(sender, instance, **kwargs):
14991546

15001547

15011548
class SourceImageQuerySet(models.QuerySet):
1502-
def with_occurrences_count(self, classification_threshold: float = 0):
1549+
def _build_default_taxa_filter(
1550+
self,
1551+
classification_threshold: float = 0,
1552+
project: Project | None = None,
1553+
) -> Q:
1554+
"""
1555+
Build a reusable Q filter that applies the classification threshold
1556+
and the project's default include/exclude taxa settings.
1557+
"""
1558+
filter_q = Q(detections__occurrence__determination_score__gte=classification_threshold)
1559+
1560+
if not project:
1561+
return filter_q
1562+
1563+
include_taxa = project.default_filters_include_taxa.all()
1564+
exclude_taxa = project.default_filters_exclude_taxa.all()
1565+
1566+
if include_taxa.exists():
1567+
include_q = Q(detections__occurrence__determination__in=include_taxa)
1568+
for taxon in include_taxa:
1569+
include_q |= Q(detections__occurrence__determination__parents_json__contains=[{"id": taxon.pk}])
1570+
filter_q &= include_q
1571+
1572+
if exclude_taxa.exists():
1573+
exclude_q = Q(detections__occurrence__determination__in=exclude_taxa)
1574+
for taxon in exclude_taxa:
1575+
exclude_q |= Q(detections__occurrence__determination__parents_json__contains=[{"id": taxon.pk}])
1576+
filter_q &= ~exclude_q
1577+
1578+
return filter_q
1579+
1580+
def with_occurrences_count(self, classification_threshold: float = 0, project: Project | None = None):
1581+
filter_q = self._build_default_taxa_filter(classification_threshold, project)
15031582
return self.annotate(
15041583
occurrences_count=models.Count(
15051584
"detections__occurrence",
1506-
filter=models.Q(
1507-
detections__occurrence__determination_score__gte=classification_threshold,
1508-
),
1585+
filter=filter_q,
15091586
distinct=True,
15101587
)
15111588
)
15121589

1513-
def with_taxa_count(self, classification_threshold: float = 0):
1590+
def with_taxa_count(self, classification_threshold: float = 0, project: Project | None = None):
1591+
filter_q = self._build_default_taxa_filter(classification_threshold, project)
1592+
15141593
return self.annotate(
15151594
taxa_count=models.Count(
15161595
"detections__occurrence__determination",
1517-
filter=models.Q(
1518-
detections__occurrence__determination_score__gte=classification_threshold,
1519-
),
1596+
filter=filter_q,
15201597
distinct=True,
15211598
)
15221599
)
@@ -2466,6 +2543,33 @@ def filter_by_score_threshold(self, project: Project | None = None, request: Req
24662543
score_threshold = get_default_classification_threshold(project, request)
24672544
return self.filter(determination_score__gte=score_threshold)
24682545

2546+
def filter_by_project_default_taxa(self, project: Project | None = None, request: Request | None = None):
2547+
if project is None:
2548+
return self
2549+
if request is not None:
2550+
apply_defaults = request.query_params.get("apply_defaults", "true").lower()
2551+
if apply_defaults == "false":
2552+
return self
2553+
qs = self
2554+
include_taxa = project.default_filters_include_taxa.all()
2555+
exclude_taxa = project.default_filters_exclude_taxa.all()
2556+
2557+
include_filter = Q()
2558+
if include_taxa.exists():
2559+
include_filter = Q(determination__in=include_taxa)
2560+
for taxon in include_taxa:
2561+
include_filter |= Q(determination__parents_json__contains=[{"id": taxon.pk}])
2562+
qs = qs.filter(include_filter)
2563+
2564+
exclude_filter = Q()
2565+
if exclude_taxa.exists():
2566+
exclude_filter = Q(determination__in=exclude_taxa)
2567+
for taxon in exclude_taxa:
2568+
exclude_filter |= Q(determination__parents_json__contains=[{"id": taxon.pk}])
2569+
qs = qs.exclude(exclude_filter)
2570+
2571+
return qs
2572+
24692573

24702574
class OccurrenceManager(models.Manager.from_queryset(OccurrenceQuerySet)):
24712575
def get_queryset(self):
@@ -2732,6 +2836,33 @@ def with_occurrence_counts(self, project: Project):
27322836

27332837
return qs.annotate(occurrence_count=models.Count("occurrences", distinct=True))
27342838

2839+
def filter_by_project_default_taxa(self, project: Project | None = None, request: Request | None = None):
2840+
"""
2841+
Filter taxa according to a project's default include and exclude settings,
2842+
keeping taxa in the include set along with their descendants
2843+
and removing taxa in the exclude set along with their descendants.
2844+
"""
2845+
if project is None:
2846+
return self
2847+
2848+
qs = self
2849+
include_taxa = project.default_filters_include_taxa.all()
2850+
exclude_taxa = project.default_filters_exclude_taxa.all()
2851+
2852+
if include_taxa.exists():
2853+
include_filter = Q(id__in=include_taxa)
2854+
for taxon in include_taxa:
2855+
include_filter |= Q(parents_json__contains=[{"id": taxon.pk}])
2856+
qs = qs.filter(include_filter)
2857+
2858+
if exclude_taxa.exists():
2859+
exclude_filter = Q(id__in=exclude_taxa)
2860+
for taxon in exclude_taxa:
2861+
exclude_filter |= Q(parents_json__contains=[{"id": taxon.pk}])
2862+
qs = qs.exclude(exclude_filter)
2863+
2864+
return qs
2865+
27352866

27362867
@final
27372868
class TaxonManager(models.Manager.from_queryset(TaxonQuerySet)):

0 commit comments

Comments
 (0)