Skip to content

Commit 378b9dc

Browse files
authored
Remove default classification threshold (#613)
* feat: remove the default classification threshold on all queries * chore: update formatting
1 parent 67e95d0 commit 378b9dc

File tree

4 files changed

+35
-33
lines changed

4 files changed

+35
-33
lines changed

ami/main/api/views.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -374,12 +374,7 @@ class SourceImageViewSet(DefaultViewSet):
374374
GET /captures/1/
375375
"""
376376

377-
queryset = (
378-
SourceImage.objects.all()
379-
.with_occurrences_count() # type: ignore
380-
.with_taxa_count()
381-
# .with_detections_count()
382-
)
377+
queryset = SourceImage.objects.all()
383378

384379
serializer_class = SourceImageSerializer
385380
filterset_fields = ["event", "deployment", "deployment__project", "collections"]
@@ -411,6 +406,13 @@ def get_queryset(self) -> QuerySet:
411406
queryset = super().get_queryset()
412407
with_detections_default = False
413408

409+
classification_threshold = get_active_classification_threshold(self.request)
410+
queryset = queryset.with_occurrences_count( # type: ignore
411+
classification_threshold=classification_threshold
412+
).with_taxa_count( # type: ignore
413+
classification_threshold=classification_threshold
414+
)
415+
414416
queryset.select_related(
415417
"event",
416418
"deployment",
@@ -542,8 +544,6 @@ class SourceImageCollectionViewSet(DefaultViewSet):
542544
SourceImageCollection.objects.all()
543545
.with_source_images_count() # type: ignore
544546
.with_source_images_with_detections_count()
545-
.with_occurrences_count()
546-
.with_taxa_count()
547547
.prefetch_related("jobs")
548548
)
549549
serializer_class = SourceImageCollectionSerializer
@@ -559,6 +559,16 @@ class SourceImageCollectionViewSet(DefaultViewSet):
559559
"occurrences_count",
560560
]
561561

562+
def get_queryset(self) -> QuerySet:
563+
classification_threshold = get_active_classification_threshold(self.request)
564+
queryset = (
565+
super()
566+
.get_queryset()
567+
.with_occurrences_count(classification_threshold=classification_threshold) # type: ignore
568+
.with_taxa_count(classification_threshold=classification_threshold)
569+
)
570+
return queryset
571+
562572
@action(detail=True, methods=["post"], name="populate")
563573
def populate(self, request, pk=None):
564574
"""

ami/main/models.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,6 @@ def update_calculated_fields(self, save=False):
580580
self.detections_count = Detection.objects.filter(Q(source_image__deployment=self)).count()
581581
self.occurrences_count = (
582582
self.occurrences.filter(
583-
determination_score__gte=settings.DEFAULT_CONFIDENCE_THRESHOLD,
584583
event__isnull=False,
585584
)
586585
.distinct()
@@ -589,7 +588,6 @@ def update_calculated_fields(self, save=False):
589588
self.taxa_count = (
590589
Taxon.objects.filter(
591590
occurrences__deployment=self,
592-
occurrences__determination_score__gte=settings.DEFAULT_CONFIDENCE_THRESHOLD,
593591
occurrences__event__isnull=False,
594592
)
595593
.distinct()
@@ -710,12 +708,8 @@ def get_captures_count(self) -> int:
710708
def get_detections_count(self) -> int | None:
711709
return Detection.objects.filter(Q(source_image__event=self)).count()
712710

713-
def get_occurrences_count(self, classification_threshold: int | None = None) -> int:
714-
return (
715-
self.occurrences.distinct()
716-
.filter(determination_score__gte=classification_threshold or settings.DEFAULT_CONFIDENCE_THRESHOLD)
717-
.count()
718-
)
711+
def get_occurrences_count(self, classification_threshold: float = 0) -> int:
712+
return self.occurrences.distinct().filter(determination_score__gte=classification_threshold).count()
719713

720714
def stats(self) -> dict[str, int | None]:
721715
return (
@@ -728,15 +722,15 @@ def stats(self) -> dict[str, int | None]:
728722
)
729723
)
730724

731-
def taxa_count(self, classification_threshold: int | None = None) -> int:
725+
def taxa_count(self, classification_threshold: float = 0) -> int:
732726
# Move this to a pre-calculated field or prefetch_related in the view
733727
# return self.taxa(classification_threshold).count()
734728
return 0
735729

736-
def taxa(self, classification_threshold: int | None = None) -> models.QuerySet["Taxon"]:
730+
def taxa(self, classification_threshold: float = 0) -> models.QuerySet["Taxon"]:
737731
return Taxon.objects.filter(
738732
Q(occurrences__event=self),
739-
occurrences__determination_score__gte=classification_threshold or settings.DEFAULT_CONFIDENCE_THRESHOLD,
733+
occurrences__determination_score__gte=classification_threshold,
740734
).distinct()
741735

742736
def first_capture(self):
@@ -1145,23 +1139,23 @@ def delete_source_image(sender, instance, **kwargs):
11451139

11461140

11471141
class SourceImageQuerySet(models.QuerySet):
1148-
def with_occurrences_count(self):
1142+
def with_occurrences_count(self, classification_threshold: float = 0):
11491143
return self.annotate(
11501144
occurrences_count=models.Count(
11511145
"detections__occurrence",
11521146
filter=models.Q(
1153-
detections__occurrence__determination_score__gte=settings.DEFAULT_CONFIDENCE_THRESHOLD
1147+
detections__occurrence__determination_score__gte=classification_threshold,
11541148
),
11551149
distinct=True,
11561150
)
11571151
)
11581152

1159-
def with_taxa_count(self):
1153+
def with_taxa_count(self, classification_threshold: float = 0):
11601154
return self.annotate(
11611155
taxa_count=models.Count(
11621156
"detections__occurrence__determination",
11631157
filter=models.Q(
1164-
detections__occurrence__determination_score__gte=settings.DEFAULT_CONFIDENCE_THRESHOLD
1158+
detections__occurrence__determination_score__gte=classification_threshold,
11651159
),
11661160
distinct=True,
11671161
)
@@ -2475,7 +2469,7 @@ def occurrence_images(
24752469
self,
24762470
limit: int | None = 10,
24772471
project_id: int | None = None,
2478-
classification_threshold: float | None = None,
2472+
classification_threshold: float = 0,
24792473
) -> list[str]:
24802474
"""
24812475
Return one image from each occurrence of this Taxon.
@@ -2489,8 +2483,6 @@ def occurrence_images(
24892483
Use the request to generate the full media URLs.
24902484
"""
24912485

2492-
classification_threshold = classification_threshold or settings.DEFAULT_CONFIDENCE_THRESHOLD
2493-
24942486
# Retrieve the URLs using a single optimized query
24952487
qs = (
24962488
self.occurrences.prefetch_related(
@@ -2663,22 +2655,25 @@ def with_source_images_with_detections_count(self):
26632655
)
26642656
)
26652657

2666-
def with_occurrences_count(self):
2658+
def with_occurrences_count(self, classification_threshold: float = 0):
26672659
return self.annotate(
26682660
occurrences_count=models.Count(
26692661
"images__detections__occurrence",
26702662
filter=models.Q(
2671-
images__detections__occurrence__determination_score__gte=settings.DEFAULT_CONFIDENCE_THRESHOLD
2663+
images__detections__occurrence__determination_score__gte=classification_threshold,
26722664
),
26732665
distinct=True,
26742666
)
26752667
)
26762668

2677-
def with_taxa_count(self):
2669+
def with_taxa_count(self, classification_threshold: float = 0):
26782670
return self.annotate(
26792671
taxa_count=models.Count(
26802672
"images__detections__occurrence__determination",
26812673
distinct=True,
2674+
filter=models.Q(
2675+
images__detections__occurrence__determination_score__gte=classification_threshold,
2676+
),
26822677
)
26832678
)
26842679

ami/utils/requests.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from django.conf import settings
21
from django.forms import FloatField
32
from rest_framework.request import Request
43

@@ -10,5 +9,5 @@ def get_active_classification_threshold(request: Request) -> float:
109
if classification_threshold is not None:
1110
classification_threshold = FloatField(required=False).clean(classification_threshold)
1211
else:
13-
classification_threshold = settings.DEFAULT_CONFIDENCE_THRESHOLD
12+
classification_threshold = 0
1413
return classification_threshold

config/settings/base.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,6 @@
369369
# Your stuff...
370370
# ------------------------------------------------------------------------------
371371

372-
DEFAULT_CONFIDENCE_THRESHOLD = env.float("DEFAULT_CONFIDENCE_THRESHOLD", default=0.6) # type: ignore[no-untyped-call]
373-
374372
S3_TEST_ENDPOINT = env("MINIO_ENDPOINT", default="http://minio:9000") # type: ignore[no-untyped-call]
375373
S3_TEST_KEY = env("MINIO_ROOT_USER", default=None) # type: ignore[no-untyped-call]
376374
S3_TEST_SECRET = env("MINIO_ROOT_PASSWORD", default=None) # type: ignore[no-untyped-call]

0 commit comments

Comments
 (0)