diff --git a/ami/main/api/views.py b/ami/main/api/views.py index c5306c892..2b0ea4755 100644 --- a/ami/main/api/views.py +++ b/ami/main/api/views.py @@ -30,7 +30,7 @@ from ami.base.serializers import FilterParamsSerializer, SingleParamSerializer from ami.base.views import ProjectMixin from ami.main.api.serializers import TagSerializer -from ami.utils.requests import get_active_classification_threshold, project_id_doc_param +from ami.utils.requests import get_default_classification_threshold, project_id_doc_param from ami.utils.storages import ConnectionTestResult from ..models import ( @@ -336,7 +336,9 @@ def get_queryset(self) -> QuerySet: "occurrences__determination", distinct=True, filter=models.Q( - occurrences__determination_score__gte=get_active_classification_threshold(self.request), + occurrences__determination_score__gte=get_default_classification_threshold( + project, self.request + ), ), ), ) @@ -493,9 +495,10 @@ def get_serializer_context(self): def get_queryset(self) -> QuerySet: queryset = super().get_queryset() + project = self.get_active_project() with_detections_default = False - classification_threshold = get_active_classification_threshold(self.request) + classification_threshold = get_default_classification_threshold(project, self.request) queryset = queryset.with_occurrences_count( # type: ignore classification_threshold=classification_threshold ).with_taxa_count( # type: ignore @@ -652,9 +655,9 @@ class SourceImageCollectionViewSet(DefaultViewSet, ProjectMixin): ] def get_queryset(self) -> QuerySet: - classification_threshold = get_active_classification_threshold(self.request) query_set: QuerySet = super().get_queryset() project = self.get_active_project() + classification_threshold = get_default_classification_threshold(project, self.request) if project: query_set = query_set.filter(project=project) queryset = query_set.with_occurrences_count( # type: ignore @@ -1052,11 +1055,6 @@ def filter_queryset(self, request, queryset, view): return queryset -OccurrenceDeterminationScoreFilter = ThresholdFilter.create( - query_param="classification_threshold", filter_param="determination_score" -) - - class OccurrenceViewSet(DefaultViewSet, ProjectMixin): """ API endpoint that allows occurrences to be viewed or edited. @@ -1074,7 +1072,6 @@ class OccurrenceViewSet(DefaultViewSet, ProjectMixin): OccurrenceVerified, OccurrenceVerifiedByMeFilter, OccurrenceTaxaListFilter, - OccurrenceDeterminationScoreFilter, ] filterset_fields = [ "event", @@ -1118,7 +1115,7 @@ def get_queryset(self) -> QuerySet["Occurrence"]: ) qs = qs.with_detections_count().with_timestamps() # type: ignore qs = qs.with_identifications() # type: ignore - + qs = qs.filter_by_score_threshold(project, self.request) # type: ignore if self.action != "list": qs = qs.prefetch_related( Prefetch( @@ -1360,7 +1357,9 @@ def get_queryset(self) -> QuerySet: qs = qs.prefetch_related( Prefetch( "occurrences", - queryset=Occurrence.objects.filter(self.get_occurrence_filters(project))[:1], + queryset=Occurrence.objects.filter_by_score_threshold( # type: ignore + project, self.request + ).filter(self.get_occurrence_filters(project))[:1], to_attr="example_occurrences", ) ) @@ -1385,6 +1384,7 @@ def get_taxa_observed(self, qs: QuerySet, project: Project, include_unobserved=F occurrence_filters, determination_id=models.OuterRef("id"), ) + .filter_by_score_threshold(project, self.request) # type: ignore .values("determination_id") .annotate(count=models.Count("id")) .values("count"), @@ -1427,6 +1427,8 @@ def get_taxa_observed(self, qs: QuerySet, project: Project, include_unobserved=F Occurrence.objects.filter( occurrence_filters, determination_id=models.OuterRef("id"), + ).filter_by_score_threshold( # type: ignore + project, self.request ), ) ) @@ -1564,13 +1566,15 @@ def get(self, request): "captures_count": SourceImage.objects.visible_for_user(user) # type: ignore .filter(deployment__project=project) .count(), - "occurrences_count": Occurrence.objects.valid() + "occurrences_count": Occurrence.objects.valid() # type: ignore .visible_for_user(user) .filter(project=project) - .count(), # type: ignore - "taxa_count": Occurrence.objects.visible_for_user(user) + .filter_by_score_threshold(project, self.request) # type: ignore + .count(), + "taxa_count": Occurrence.objects.visible_for_user(user) # type: ignore + .filter_by_score_threshold(project, self.request) .unique_taxa(project=project) - .count(), # type: ignore + .count(), } else: data = { diff --git a/ami/main/migrations/0075_auto_20250922_0130.py b/ami/main/migrations/0075_auto_20250922_0130.py new file mode 100644 index 000000000..05a6eee3a --- /dev/null +++ b/ami/main/migrations/0075_auto_20250922_0130.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.10 on 2025-09-22 01:30 + +from django.db import migrations + + +def refresh_deployment_counts(apps, schema_editor): + Deployment = apps.get_model("main", "Deployment") + for dep in Deployment.objects.all(): + dep.save() # triggers save logic and recalculates counts + + +class Migration(migrations.Migration): + dependencies = [ + ("main", "0074_taxon_cover_image_credit_taxon_cover_image_url_and_more"), + ] + + operations = [ + migrations.RunPython(refresh_deployment_counts, reverse_code=migrations.RunPython.noop), + ] diff --git a/ami/main/models.py b/ami/main/models.py index 7e3935609..bdad5dca5 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -18,7 +18,7 @@ from django.core.exceptions import ValidationError from django.core.files.storage import default_storage from django.db import IntegrityError, models, transaction -from django.db.models import Q +from django.db.models import Exists, OuterRef, Q from django.db.models.fields.files import ImageFieldFile from django.db.models.signals import pre_delete from django.dispatch import receiver @@ -26,6 +26,7 @@ from django.utils import timezone from django_pydantic_field import SchemaField from guardian.shortcuts import get_perms +from rest_framework.request import Request import ami.tasks import ami.utils @@ -36,6 +37,7 @@ from ami.ml.schemas import BoundingBox from ami.users.models import User from ami.utils.media import calculate_file_checksum, extract_timestamp +from ami.utils.requests import get_default_classification_threshold from ami.utils.schemas import OrderedEnum if typing.TYPE_CHECKING: @@ -286,10 +288,26 @@ def summary_data(self): return plots - def save(self, *args, **kwargs): + def update_related_calculated_fields(self): + """ + Update calculated fields for all related events and deployments. + """ + # Update events + for event in self.events.all(): + event.update_calculated_fields(save=True) + + # Update deployments + for deployment in self.deployments.all(): + deployment.update_calculated_fields(save=True) + + def save(self, *args, update_related_calculated_fields: bool = True, **kwargs): super().save(*args, **kwargs) # Add owner to members self.ensure_owner_membership() + # Update calculated fields including filtered occurrence counts + # and taxa counts for related deployments and events + if update_related_calculated_fields: + self.update_related_calculated_fields() class Permissions: """CRUD Permission names follow the convention: `create_`, `update_`, @@ -833,22 +851,15 @@ def update_calculated_fields(self, save=False): self.events_count = self.events.count() self.captures_count = self.data_source_total_files or self.captures.count() self.detections_count = Detection.objects.filter(Q(source_image__deployment=self)).count() - self.occurrences_count = ( - self.occurrences.filter( - event__isnull=False, - ) - .distinct() - .count() - ) - self.taxa_count = ( - Taxon.objects.filter( - occurrences__deployment=self, - occurrences__event__isnull=False, - ) - .distinct() - .count() + occ_qs = self.occurrences.filter(event__isnull=False).filter_by_score_threshold( # type: ignore + project=self.project, + request=None, ) + self.occurrences_count = occ_qs.distinct().count() + + self.taxa_count = Taxon.objects.filter(id__in=occ_qs.values("determination_id")).distinct().count() + self.first_capture_timestamp, self.last_capture_timestamp = self.get_first_and_last_timestamps() if save: @@ -2459,6 +2470,12 @@ def unique_taxa(self, project: Project | None = None): ) return qs + def filter_by_score_threshold(self, project: Project | None = None, request: Request | None = None): + if project is None: + return self + score_threshold = get_default_classification_threshold(project, request) + return self.filter(determination_score__gte=score_threshold) + class OccurrenceManager(models.Manager.from_queryset(OccurrenceQuerySet)): def get_queryset(self): @@ -2725,6 +2742,34 @@ def with_occurrence_counts(self, project: Project): return qs.annotate(occurrence_count=models.Count("occurrences", distinct=True)) + def visible_for_user(self, user: User | AnonymousUser): + if user.is_superuser: + return self + + is_anonymous = isinstance(user, AnonymousUser) + + # Visible projects + project_qs = Project.objects.all() + if is_anonymous: + project_qs = project_qs.filter(draft=False) + else: + project_qs = project_qs.filter(Q(draft=False) | Q(owner=user) | Q(members=user)) + + # Taxa explicitly linked to visible projects + direct_taxa = self.filter(projects__in=project_qs) + + # Taxa with at least one occurrence in visible projects + occurrence_taxa = self.filter( + Exists( + Occurrence.objects.filter( + project__in=project_qs, + determination_id=OuterRef("id"), + ) + ) + ) + + return (direct_taxa | occurrence_taxa).distinct() + @final class TaxonManager(models.Manager.from_queryset(TaxonQuerySet)): diff --git a/ami/main/tests.py b/ami/main/tests.py index 63cdf8491..a4f80c089 100644 --- a/ami/main/tests.py +++ b/ami/main/tests.py @@ -2280,3 +2280,358 @@ def test_summary_counts(self): ) logger.info("All exact count validations passed") + + +class TestProjectDefaultThresholdFilter(APITestCase): + """API tests for default score threshold filtering""" + + def setUp(self): + # Create project, deployment, and test data + self.project, self.deployment = setup_test_project(reuse=False) + taxa_list = create_taxa(self.project) + taxa = list(taxa_list.taxa.all()) + low_taxon = taxa[0] + high_taxon = taxa[1] + create_captures(deployment=self.deployment, num_nights=1, images_per_night=3) + + # Create multiple low and high determination score occurrences + create_occurrences(deployment=self.deployment, num=3, determination_score=0.3, taxon=low_taxon) + create_occurrences(deployment=self.deployment, num=3, determination_score=0.9, taxon=high_taxon) + + self.low_occurrences = Occurrence.objects.filter(deployment=self.deployment, determination_score=0.3) + self.high_occurrences = Occurrence.objects.filter(deployment=self.deployment, determination_score=0.9) + + # Project default threshold + self.default_threshold = 0.6 + self.project.default_filters_score_threshold = self.default_threshold + self.project.save() + + # Auth user + self.user = User.objects.create_user(email="tester@insectai.org", is_staff=False, is_superuser=False) + self.client.force_authenticate(user=self.user) + + self.url = f"/api/v2/occurrences/?project_id={self.project.pk}" + self.url_taxa = f"/api/v2/taxa/?project_id={self.project.pk}" + + # OccurrenceViewSet tests + def test_occurrences_respect_project_threshold(self): + """Occurrences below project threshold should be filtered out""" + res = self.client.get(self.url) + self.assertEqual(res.status_code, status.HTTP_200_OK) + ids = {o["id"] for o in res.data["results"]} + + # High-scoring occurrences should remain + for occ in self.high_occurrences: + self.assertIn(occ.id, ids) + # Low-scoring occurrences should be excluded + for occ in self.low_occurrences: + self.assertNotIn(occ.id, ids) + + def test_apply_defaults_false_bypasses_threshold(self): + """apply_defaults=false should allow explicit classification_threshold to override project default""" + res = self.client.get(self.url + "&apply_defaults=false&classification_threshold=0.2") + self.assertEqual(res.status_code, status.HTTP_200_OK) + ids = {o["id"] for o in res.data["results"]} + # Both sets should be included with threshold=0.2 + for occ in list(self.high_occurrences) + list(self.low_occurrences): + self.assertIn(occ.id, ids) + + def test_query_threshold_ignored_when_defaults_applied(self): + """classification_threshold param is ignored if apply_defaults is not false""" + res = self.client.get(self.url + "&classification_threshold=0.1") + self.assertEqual(res.status_code, status.HTTP_200_OK) + ids = {o["id"] for o in res.data["results"]} + # Still should apply project default (0.5) + for occ in self.high_occurrences: + self.assertIn(occ.id, ids) + for occ in self.low_occurrences: + self.assertNotIn(occ.id, ids) + + def test_no_project_id_returns_all(self): + """Without project_id, threshold falls back to 0.0 and returns all occurrences""" + url = "/api/v2/occurrences/" + res = self.client.get(url) + self.assertEqual(res.status_code, status.HTTP_200_OK) + ids = {o["id"] for o in res.data["results"]} + # All occurrences should appear + for occ in list(self.high_occurrences) + list(self.low_occurrences): + self.assertIn(occ.id, ids) + + def test_retrieve_occurrence_respects_threshold(self): + """Detail retrieval should 404 if occurrence is filtered out by threshold""" + low_occ = self.low_occurrences[0] + detail_url = f"/api/v2/occurrences/{low_occ.id}/?project_id={self.project.pk}" + res = self.client.get(detail_url) + self.assertEqual(res.status_code, status.HTTP_404_NOT_FOUND) + + high_occ = self.high_occurrences[0] + detail_url = f"/api/v2/occurrences/{high_occ.id}/?project_id={self.project.pk}" + res = self.client.get(detail_url) + self.assertEqual(res.status_code, status.HTTP_200_OK) + + # TaxonViewSet tests + def test_taxa_respect_project_threshold(self): + """Taxa with only low-score occurrences should be excluded""" + res = self.client.get(self.url_taxa) + self.assertEqual(res.status_code, status.HTTP_200_OK) + names = {t["name"] for t in res.data["results"]} + + for occ in self.high_occurrences: + self.assertIn(occ.determination.name, names) + for occ in self.low_occurrences: + self.assertNotIn(occ.determination.name, names) + + def test_apply_defaults_false_bypasses_threshold_taxa(self): + """apply_defaults=false should allow low-score taxa to appear""" + res = self.client.get(self.url_taxa + "&apply_defaults=false&classification_threshold=0.2") + self.assertEqual(res.status_code, status.HTTP_200_OK) + names = {t["name"] for t in res.data["results"]} + + for occ in list(self.high_occurrences) + list(self.low_occurrences): + self.assertIn(occ.determination.name, names) + + def test_query_threshold_ignored_when_defaults_applied_taxa(self): + """classification_threshold is ignored when defaults apply""" + res = self.client.get(self.url_taxa + "&classification_threshold=0.1") + self.assertEqual(res.status_code, status.HTTP_200_OK) + names = {t["name"] for t in res.data["results"]} + + for occ in self.high_occurrences: + self.assertIn(occ.determination.name, names) + for occ in self.low_occurrences: + self.assertNotIn(occ.determination.name, names) + + def test_include_unobserved_true_returns_unobserved_taxa(self): + """include_unobserved=true should return taxa even without valid occurrences""" + res = self.client.get(self.url_taxa + "&include_unobserved=true") + self.assertEqual(res.status_code, status.HTTP_200_OK) + # There should be more taxa than just the ones tied to high occurrences + self.assertGreater(len(res.data["results"]), self.high_occurrences.count()) + + def test_taxon_detail_example_occurrences_respects_threshold(self): + """Detail view should prefetch only above-threshold occurrences""" + taxon = self.high_occurrences.first().determination + detail_url = f"/api/v2/taxa/{taxon.id}/?project_id={self.project.pk}" + res = self.client.get(detail_url) + self.assertEqual(res.status_code, status.HTTP_200_OK) + example_occ = res.data.get("example_occurrences", []) + self.assertTrue(all(o["determination_score"] >= 0.6 for o in example_occ)) + + def test_taxa_count_matches_summary_with_threshold(self): + """Taxa count from taxa endpoint should match taxa_count in summary when defaults applied""" + # Get taxa list + res_taxa = self.client.get(self.url_taxa) + self.assertEqual(res_taxa.status_code, status.HTTP_200_OK) + taxa_count = len(res_taxa.data["results"]) + + # Get summary (global status summary, filtered by project_id) + url_summary = f"/api/v2/status/summary/?project_id={self.project.pk}" + res_summary = self.client.get(url_summary) + self.assertEqual(res_summary.status_code, status.HTTP_200_OK) + + summary_taxa_count = res_summary.data["taxa_count"] + + self.assertEqual( + taxa_count, + summary_taxa_count, + f"Mismatch: taxa endpoint returned {taxa_count}, summary returned {summary_taxa_count}", + ) + + # SourceImageViewSet tests + def test_source_image_counts_respect_threshold(self): + """occurrences_count and taxa_count should exclude low-score occurrences (per-capture assertions).""" + url = f"/api/v2/captures/?project_id={self.project.pk}" + res = self.client.get(url) + self.assertEqual(res.status_code, status.HTTP_200_OK) + + for capture in res.data["results"]: + cap_id = capture["id"] + + # All occurrences linked to this capture via detections + cap_occs = Occurrence.objects.filter( + detections__source_image_id=cap_id, + deployment=self.deployment, + ).distinct() + + cap_high_occs = cap_occs.filter(determination_score__gte=self.default_threshold) + + # Expected counts for this capture under default threshold + expected_occurrences_count = cap_high_occs.count() + expected_taxa_count = cap_high_occs.values("determination_id").distinct().count() + + # Exact assertions against the API’s annotated fields + self.assertEqual(capture["occurrences_count"], expected_occurrences_count) + self.assertEqual(capture["taxa_count"], expected_taxa_count) + + # If capture only has low-score occurrences, both counts must be zero + if cap_occs.exists() and not cap_high_occs.exists(): + self.assertEqual(capture["occurrences_count"], 0) + self.assertEqual(capture["taxa_count"], 0) + + def _make_collection_with_some_images(self, name="Test Manual Source Image Collection"): + """Create a manual collection including a few of this deployment's captures using populate_sample().""" + images = list(SourceImage.objects.filter(deployment=self.deployment).order_by("id")) + self.assertGreaterEqual(len(images), 3, "Need at least 3 source images from setup") + + collection = SourceImageCollection.objects.create( + name=name, + project=self.project, + method="manual", + kwargs={"image_ids": [img.pk for img in images[:3]]}, # deterministic subset + ) + collection.save() + collection.populate_sample() + return collection + + def _expected_counts_for_collection(self, collection, threshold: float) -> tuple[int, int]: + """Return (occurrences_count, taxa_count) for a collection under a given threshold.""" + coll_occs = Occurrence.objects.filter( + detections__source_image__collections=collection, + deployment=self.deployment, + ).distinct() + coll_high = coll_occs.filter(determination_score__gte=threshold) + occ_count = coll_high.count() + taxa_count = coll_high.values("determination_id").distinct().count() + return occ_count, taxa_count + + # SourceImageCollectionViewSet tests + def test_collections_counts_respect_threshold(self): + """occurrences_count and taxa_count on collections should exclude low-score occurrences.""" + collection = self._make_collection_with_some_images() + + url = f"/api/v2/captures/collections/?project_id={self.project.pk}" + res = self.client.get(url) + self.assertEqual(res.status_code, status.HTTP_200_OK) + + row = next((r for r in res.data["results"] if r["id"] == collection.id), None) + self.assertIsNotNone(row, "Expected the created collection in list response") + + expected_occ, expected_taxa = self._expected_counts_for_collection(collection, self.default_threshold) + self.assertEqual(row["occurrences_count"], expected_occ) + self.assertEqual(row["taxa_count"], expected_taxa) + + def _expected_event_taxa_count(self, event, threshold: float) -> int: + """Distinct determinations among this event's occurrences at/above threshold.""" + return ( + Occurrence.objects.filter( + event=event, + determination_score__gte=threshold, + ) + .values("determination_id") + .distinct() + .count() + ) + + # EventViewSet tests + def test_event_taxa_count_respects_threshold(self): + create_captures(deployment=self.deployment, num_nights=3, images_per_night=3) + group_images_into_events(deployment=self.deployment) + + url = f"/api/v2/events/?project_id={self.project.pk}" + res = self.client.get(url) + self.assertEqual(res.status_code, status.HTTP_200_OK) + + expected = { + e.pk: self._expected_event_taxa_count(e, self.default_threshold) + for e in Event.objects.filter(deployment__project=self.project) + } + + for row in res.data["results"]: + self.assertEqual(row["taxa_count"], expected[row["id"]]) + + # SummaryView tests + def test_summary_counts_respect_project_threshold(self): + """Summary should apply project default threshold to occurrences_count and taxa_count.""" + url = f"/api/v2/status/summary/?project_id={self.project.pk}" + res = self.client.get(url) + self.assertEqual(res.status_code, status.HTTP_200_OK) + + expected_occurrences = ( + Occurrence.objects.valid() + .filter(project=self.project, determination_score__gte=self.default_threshold) + .count() + ) + expected_taxa = ( + Occurrence.objects.filter( + project=self.project, + determination_score__gte=self.default_threshold, + ) + .values("determination_id") + .distinct() + .count() + ) + + self.assertEqual(res.data["occurrences_count"], expected_occurrences) + self.assertEqual(res.data["taxa_count"], expected_taxa) + + # DeploymentViewSet tests + def test_deployment_counts_respect_threshold(self): + """occurrences_count and taxa_count on deployments should exclude low-score occurrences.""" + # Call the save() method to refresh counts + for dep in Deployment.objects.all(): + dep.save() + url = f"/api/v2/deployments/?project_id={self.project.pk}" + res = self.client.get(url) + self.assertEqual(res.status_code, status.HTTP_200_OK) + + for row in res.data["results"]: + dep_id = row["id"] + dep = Deployment.objects.get(pk=dep_id) + + # All occurrences for this deployment + dep_occs = Occurrence.objects.filter(deployment=dep).distinct() + dep_high_occs = dep_occs.filter(determination_score__gte=self.default_threshold) + + expected_occurrences_count = dep_high_occs.count() + expected_taxa_count = dep_high_occs.values("determination_id").distinct().count() + + # Assert the API matches expected counts + self.assertEqual(row["occurrences_count"], expected_occurrences_count) + self.assertEqual(row["taxa_count"], expected_taxa_count) + + # If deployment only has low-score occurrences, both counts must be zero + if dep_occs.exists() and not dep_high_occs.exists(): + self.assertEqual(row["occurrences_count"], 0) + self.assertEqual(row["taxa_count"], 0) + + def test_taxa_include_occurrence_determinations_not_directly_linked(self): + """ + Taxa should still appear in taxa list and summary if they come from + determinations of occurrences in the project, even when those taxa are + not directly linked to the project via the M2M field. + """ + # Clear existing taxa and occurrences for a clean slate + self.project.taxa.clear() + Occurrence.objects.filter(project=self.project).delete() + # Create a new taxon not linked to the project + outside_taxon = Taxon.objects.create(name="OutsideTaxon") + + # Create occurrences in this project with that taxon as determination + create_occurrences( + deployment=self.deployment, + num=2, + determination_score=0.9, + taxon=outside_taxon, + ) + + # Confirm taxon is not directly associated with the project + self.assertFalse(self.project in outside_taxon.projects.all()) + + # Taxa endpoint should include the taxon (because of occurrences) + res_taxa = self.client.get(self.url_taxa) + self.assertEqual(res_taxa.status_code, status.HTTP_200_OK) + taxa_names = {t["name"] for t in res_taxa.data["results"]} + self.assertIn(outside_taxon.name, taxa_names) + + # Summary should also count it + url_summary = f"/api/v2/status/summary/?project_id={self.project.pk}" + res_summary = self.client.get(url_summary) + self.assertEqual(res_summary.status_code, status.HTTP_200_OK) + summary_taxa_count = res_summary.data["taxa_count"] + + taxa_count = len(res_taxa.data["results"]) + self.assertEqual( + taxa_count, + summary_taxa_count, + f"Mismatch with outside taxon: taxa endpoint returned {taxa_count}, summary {summary_taxa_count}", + ) diff --git a/ami/tests/fixtures/main.py b/ami/tests/fixtures/main.py index 689a9ecb2..398a27605 100644 --- a/ami/tests/fixtures/main.py +++ b/ami/tests/fixtures/main.py @@ -360,17 +360,20 @@ def create_occurrences( deployment: Deployment, num: int = 6, taxon: Taxon | None = None, + determination_score: float = 0.9, ): # Get all source images for the deployment that have an event source_images = list(SourceImage.objects.filter(deployment=deployment)) if not source_images: raise ValueError("No source images with events found for deployment") - # Get taxon if not provided + # Get a random taxon if not provided if not taxon: - taxon = Taxon.objects.filter(projects=deployment.project).order_by("?").first() - if not taxon: + taxa_qs = Taxon.objects.filter(projects=deployment.project) + count = taxa_qs.count() + if count == 0: raise ValueError("No taxa found for project") + taxon = taxa_qs[random.randint(0, count - 1)] # Create occurrences evenly distributed across all source images for i in range(num): @@ -385,7 +388,7 @@ def create_occurrences( detection.classifications.create( taxon=taxon, - score=0.9, + score=determination_score, timestamp=datetime.datetime.now(), ) occurrence = detection.associate_new_occurrence() diff --git a/ami/utils/requests.py b/ami/utils/requests.py index e6a4026c3..b63416674 100644 --- a/ami/utils/requests.py +++ b/ami/utils/requests.py @@ -56,6 +56,32 @@ def get_active_classification_threshold(request: Request) -> float: return classification_threshold +def get_default_classification_threshold(project, request: Request | None = None) -> float: + """ + Get the classification threshold from project settings by default, + or from request query parameters if `apply_defaults=false` is set in the request. + + Args: + project: A Project instance. + request: The incoming request object (optional). + + Returns: + The classification threshold value from project settings by default, + or from request if `apply_defaults=false` is provided. + """ + # If request exists and apply_defaults is explicitly false, get from request + if request is not None: + apply_defaults = request.query_params.get("apply_defaults", "true").lower() + if apply_defaults == "false": + return get_active_classification_threshold(request) + + # Otherwise, get from project + if project is None: + return 0.0 + + return getattr(project, "default_filters_score_threshold", 0.0) or 0.0 + + project_id_doc_param = OpenApiParameter( name="project_id", description="Filter by project ID", diff --git a/ui/src/components/filtering/default-filter-control.tsx b/ui/src/components/filtering/default-filter-control.tsx new file mode 100644 index 000000000..72813b7b5 --- /dev/null +++ b/ui/src/components/filtering/default-filter-control.tsx @@ -0,0 +1,84 @@ +import { + FormActions, + FormRow, + FormSection, +} from 'components/form/layout/layout' +import { useProjectDetails } from 'data-services/hooks/projects/useProjectDetails' +import { ProjectDetails } from 'data-services/models/project-details' +import { + IconButton, + IconButtonTheme, +} from 'design-system/components/icon-button/icon-button' +import { IconType } from 'design-system/components/icon/icon' +import { InputValue } from 'design-system/components/input/input' +import { ChevronRightIcon } from 'lucide-react' +import { buttonVariants, Popover } from 'nova-ui-kit' +import { Link, useParams } from 'react-router-dom' +import { APP_ROUTES } from 'utils/constants' +import { STRING, translate } from 'utils/language' + +export const DefaultFiltersControl = () => { + const { projectId } = useParams() + const { project } = useProjectDetails(projectId as string, true) + + return ( +
+
+ + {translate(STRING.NAV_ITEM_DEFAULT_FILTERS)} + + {project ? : null} +
+
+ ) +} + +const InfoPopover = ({ project }: { project: ProjectDetails }) => ( + + + + + + + + + + + taxon.name) + .join(', ')} + /> + taxon.name) + .join(', ')} + /> + + + + + Configure + + + + + +) diff --git a/ui/src/components/filtering/filter-control.tsx b/ui/src/components/filtering/filter-control.tsx index 6def74a19..0d6ecdd79 100644 --- a/ui/src/components/filtering/filter-control.tsx +++ b/ui/src/components/filtering/filter-control.tsx @@ -1,4 +1,4 @@ -import { X } from 'lucide-react' +import { XIcon } from 'lucide-react' import { Button } from 'nova-ui-kit' import { useFilters } from 'utils/useFilters' import { AlgorithmFilter, NotAlgorithmFilter } from './filters/algorithm-filter' @@ -7,7 +7,6 @@ import { CollectionFilter } from './filters/collection-filter' import { DateFilter } from './filters/date-filter' import { ImageFilter } from './filters/image-filter' import { PipelineFilter } from './filters/pipeline-filter' -import { ScoreFilter } from './filters/score-filter' import { SessionFilter } from './filters/session-filter' import { StationFilter } from './filters/station-filter' import { StatusFilter } from './filters/status-filter' @@ -23,8 +22,6 @@ const ComponentMap: { [key: string]: (props: FilterProps) => JSX.Element } = { algorithm: AlgorithmFilter, - best_determination_score: ScoreFilter, - classification_threshold: ScoreFilter, collection: CollectionFilter, collections: CollectionFilter, date_end: DateFilter, @@ -87,12 +84,12 @@ export const FilterControl = ({ /> {clearable && filter.value && ( )} diff --git a/ui/src/components/filtering/filters/score-filter.tsx b/ui/src/components/filtering/filters/score-filter.tsx deleted file mode 100644 index 3aebdd45b..000000000 --- a/ui/src/components/filtering/filters/score-filter.tsx +++ /dev/null @@ -1,38 +0,0 @@ -import { Slider } from 'nova-ui-kit' -import { useEffect, useState } from 'react' -import { useUserPreferences } from 'utils/userPreferences/userPreferencesContext' -import { FilterProps } from './types' - -export const ScoreFilter = ({ value, onAdd }: FilterProps) => { - const { userPreferences, setUserPreferences } = useUserPreferences() - const [displayValue, setDisplayValue] = useState( - userPreferences.scoreThreshold - ) - - useEffect(() => { - if (value?.length) { - setDisplayValue(Number(value)) - } - }, [value]) - - return ( -
- setDisplayValue(value)} - onValueCommit={([value]) => { - setDisplayValue(value) - onAdd(`${value}`) - setUserPreferences({ ...userPreferences, scoreThreshold: value }) - }} - /> - - {displayValue} - -
- ) -} diff --git a/ui/src/pages/occurrences/occurrences.tsx b/ui/src/pages/occurrences/occurrences.tsx index a8053c7af..8f222b832 100644 --- a/ui/src/pages/occurrences/occurrences.tsx +++ b/ui/src/pages/occurrences/occurrences.tsx @@ -1,3 +1,4 @@ +import { DefaultFiltersControl } from 'components/filtering/default-filter-control' import { FilterControl } from 'components/filtering/filter-control' import { FilterSection } from 'components/filtering/filter-section' import { someActive } from 'components/filtering/utils' @@ -28,7 +29,6 @@ import { useColumnSettings } from 'utils/useColumnSettings' import { useFilters } from 'utils/useFilters' import { usePagination } from 'utils/usePagination' import { useUser } from 'utils/user/userContext' -import { useUserPreferences } from 'utils/userPreferences/userPreferencesContext' import { useSelectedView } from 'utils/useSelectedView' import { useSort } from 'utils/useSort' import { OccurrenceActions } from './occurrence-actions' @@ -38,7 +38,6 @@ import { OccurrenceNavigation } from './occurrence-navigation' export const Occurrences = () => { const { user } = useUser() - const { userPreferences } = useUserPreferences() const { projectId, id } = useParams() const { columnSettings, setColumnSettings } = useColumnSettings( 'occurrences', @@ -59,9 +58,7 @@ export const Occurrences = () => { order: 'desc', }) const { pagination, setPage } = usePagination() - const { activeFilters, filters } = useFilters({ - classification_threshold: `${userPreferences.scoreThreshold}`, - }) + const { activeFilters, filters } = useFilters() const { occurrences, total, isLoading, isFetching, error } = useOccurrences({ projectId, pagination, @@ -94,15 +91,13 @@ export const Occurrences = () => { - - {taxaLists.length > 0 && ( )} - {user.loggedIn && } + { activeFilters )} > + + diff --git a/ui/src/pages/project-details/default-filters-form.tsx b/ui/src/pages/project-details/default-filters-form.tsx index f96c33cc4..a6e3db319 100644 --- a/ui/src/pages/project-details/default-filters-form.tsx +++ b/ui/src/pages/project-details/default-filters-form.tsx @@ -111,7 +111,12 @@ export const DefaultFiltersForm = ({ /> - + { const { updateProjectSettings, isLoading, isSuccess, error } = useUpdateProjectSettings(project.id) - const canView = project.canUpdate && project.featureFlags.default_filters + const canView = project.canUpdate useEffect(() => { if (!canView) { diff --git a/ui/src/pages/project/sidebar/useSidebarSections.tsx b/ui/src/pages/project/sidebar/useSidebarSections.tsx index ed44ba31f..8336c5a00 100644 --- a/ui/src/pages/project/sidebar/useSidebarSections.tsx +++ b/ui/src/pages/project/sidebar/useSidebarSections.tsx @@ -92,7 +92,7 @@ const getSidebarSections = ( }, ] : []), - ...(project.canUpdate && project.featureFlags.default_filters + ...(project.canUpdate ? [ { id: 'default-filters', diff --git a/ui/src/pages/session-details/playback/playback.tsx b/ui/src/pages/session-details/playback/playback.tsx index 8bc07ca7b..7aeedcb5b 100644 --- a/ui/src/pages/session-details/playback/playback.tsx +++ b/ui/src/pages/session-details/playback/playback.tsx @@ -6,27 +6,19 @@ import { Checkbox, CheckboxTheme, } from 'design-system/components/checkbox/checkbox' -import { useEffect, useMemo, useState } from 'react' -import { STRING, translate } from 'utils/language' -import { useUserPreferences } from 'utils/userPreferences/userPreferencesContext' +import { useEffect, useState } from 'react' import { ActivityPlot } from './activity-plot/lazy-activity-plot' import { CaptureDetails } from './capture-details/capture-details' import { CaptureNavigation } from './capture-navigation/capture-navigation' import { Frame } from './frame/frame' import styles from './playback.module.scss' import { SessionCapturesSlider } from './session-captures-slider/session-captures-slider' -import { ThresholdSlider } from './threshold-slider/threshold-slider' import { useActiveCaptureId } from './useActiveCapture' export const Playback = ({ session }: { session: SessionDetails }) => { - const { - userPreferences: { scoreThreshold }, - } = useUserPreferences() const { timeline = [] } = useSessionTimeline(session.id) const [poll, setPoll] = useState(false) const [showDetections, setShowDetections] = useState(true) - const [showDetectionsBelowThreshold, setShowDetectionsBelowThreshold] = - useState(false) const [snapToDetections, setSnapToDetections] = useState( session.numDetections ? true : false ) @@ -47,19 +39,7 @@ export const Playback = ({ session }: { session: SessionDetails }) => { } }, [activeCapture]) - const detections = useMemo(() => { - if (!activeCapture?.detections) { - return [] - } - - if (showDetectionsBelowThreshold) { - return activeCapture.detections - } - - return activeCapture.detections.filter( - (detection) => detection.score >= scoreThreshold - ) - }, [activeCapture, scoreThreshold, showDetectionsBelowThreshold]) + const detections = activeCapture?.detections ?? [] if (!session.firstCapture) { return null @@ -80,20 +60,7 @@ export const Playback = ({ session }: { session: SessionDetails }) => { )}
View settings -
- - {translate(STRING.FIELD_LABEL_SCORE_THRESHOLD)} - - -
Preferences - { - const [active, setActive] = useState(false) - const { userPreferences, setUserPreferences } = useUserPreferences() - const [displayThreshold, setDisplayThreshold] = useState( - userPreferences.scoreThreshold - ) - - const onValueCommit = (value: number) => { - setDisplayThreshold(value) - setUserPreferences({ - ...userPreferences, - scoreThreshold: value, - }) - } - - return ( -
- setDisplayThreshold(value)} - onValueCommit={([value]) => onValueCommit(value)} - onPointerDown={() => setActive(true)} - onPointerUp={() => setActive(false)} - onPointerLeave={() => { - if (active) { - onValueCommit(displayThreshold) - } - }} - /> - {displayThreshold} -
- ) -} diff --git a/ui/src/pages/species/species.tsx b/ui/src/pages/species/species.tsx index 08d3d519b..fe65fd8a7 100644 --- a/ui/src/pages/species/species.tsx +++ b/ui/src/pages/species/species.tsx @@ -1,3 +1,4 @@ +import { DefaultFiltersControl } from 'components/filtering/default-filter-control' import { FilterControl } from 'components/filtering/filter-control' import { FilterSection } from 'components/filtering/filter-section' import { useProjectDetails } from 'data-services/hooks/projects/useProjectDetails' @@ -23,7 +24,6 @@ import { STRING, translate } from 'utils/language' import { useColumnSettings } from 'utils/useColumnSettings' import { useFilters } from 'utils/useFilters' import { usePagination } from 'utils/usePagination' -import { useUserPreferences } from 'utils/userPreferences/userPreferencesContext' import { useSelectedView } from 'utils/useSelectedView' import { useSort } from 'utils/useSort' import { columns } from './species-columns' @@ -42,12 +42,9 @@ export const Species = () => { 'created-at': false, 'updated-at': false, }) - const { userPreferences } = useUserPreferences() const { sort, setSort } = useSort({ field: 'name', order: 'asc' }) const { pagination, setPage } = usePagination() - const { filters } = useFilters({ - best_determination_score: `${userPreferences.scoreThreshold}`, - }) + const { filters } = useFilters() const { species, total, isLoading, isFetching, error } = useSpecies({ projectId, sort, @@ -68,7 +65,6 @@ export const Species = () => { {taxaLists.length > 0 && ( )} - {project?.featureFlags.tags ? ( <> @@ -76,6 +72,7 @@ export const Species = () => { ) : null} +
{ diff --git a/ui/src/utils/userPreferences/constants.ts b/ui/src/utils/userPreferences/constants.ts index f5dff8749..742381e12 100644 --- a/ui/src/utils/userPreferences/constants.ts +++ b/ui/src/utils/userPreferences/constants.ts @@ -5,6 +5,5 @@ export const USER_PREFERENCES_STORAGE_KEY = 'ami-user-preferences' export const DEFAULT_PREFERENCES: UserPreferences = { columnSettings: {}, recentIdentifications: [], - scoreThreshold: 0.6, termsMessageSeen: false, } diff --git a/ui/src/utils/userPreferences/types.ts b/ui/src/utils/userPreferences/types.ts index 474ea0fca..ee1d8d23d 100644 --- a/ui/src/utils/userPreferences/types.ts +++ b/ui/src/utils/userPreferences/types.ts @@ -5,7 +5,6 @@ export interface UserPreferences { label: string value: string }[] - scoreThreshold: number termsMessageSeen?: boolean }