Skip to content

Commit d2cfe9f

Browse files
committed
Merge branch 'main' into feat/update-identifications
2 parents 977b4e2 + 78a0cb2 commit d2cfe9f

File tree

27 files changed

+445
-140
lines changed

27 files changed

+445
-140
lines changed

ami/jobs/apps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
from django.utils.translation import gettext_lazy as _
33

44

5-
class UsersConfig(AppConfig):
5+
class JobsConfig(AppConfig):
66
name = "ami.jobs"
77
verbose_name = _("Jobs")

ami/jobs/models.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -399,10 +399,9 @@ def run(cls, job: "Job"):
399399
total_detections = 0
400400
total_classifications = 0
401401

402-
# Set to low size because our response JSON just got enormous
403-
# @TODO make this configurable
404-
CHUNK_SIZE = 1
405-
chunks = [images[i : i + CHUNK_SIZE] for i in range(0, image_count, CHUNK_SIZE)] # noqa
402+
config = job.pipeline.get_config(project_id=job.project.pk)
403+
chunk_size = config.get("request_source_image_batch_size", 1)
404+
chunks = [images[i : i + chunk_size] for i in range(0, image_count, chunk_size)] # noqa
406405
request_failed_images = []
407406

408407
for i, chunk in enumerate(chunks):
@@ -434,9 +433,9 @@ def run(cls, job: "Job"):
434433
"process",
435434
status=JobState.STARTED,
436435
progress=(i + 1) / len(chunks),
437-
processed=min((i + 1) * CHUNK_SIZE, image_count),
436+
processed=min((i + 1) * chunk_size, image_count),
438437
failed=len(request_failed_images),
439-
remaining=max(image_count - ((i + 1) * CHUNK_SIZE), 0),
438+
remaining=max(image_count - ((i + 1) * chunk_size), 0),
440439
)
441440

442441
# count the completed, successful, and failed save_tasks:

ami/labelstudio/apps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
from django.utils.translation import gettext_lazy as _
33

44

5-
class UsersConfig(AppConfig):
5+
class LabelStudioConfig(AppConfig):
66
name = "ami.labelstudio"
77
verbose_name = _("Label Studio Integration")

ami/main/admin.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,13 @@ class TaxaListAdmin(admin.ModelAdmin[TaxaList]):
517517
def taxa_count(self, obj) -> int:
518518
return obj.taxa.count()
519519

520+
autocomplete_fields = (
521+
"taxa",
522+
"projects",
523+
)
524+
525+
list_filter = ("projects",)
526+
520527

521528
@admin.register(Device)
522529
class DeviceAdmin(admin.ModelAdmin[Device]):

ami/main/api/serializers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
SourceImage,
3131
SourceImageCollection,
3232
SourceImageUpload,
33+
TaxaList,
3334
Taxon,
3435
)
3536

@@ -532,6 +533,25 @@ def get_occurrences(self, obj):
532533
)
533534

534535

536+
class TaxaListSerializer(serializers.ModelSerializer):
537+
taxa = serializers.SerializerMethodField()
538+
projects = serializers.PrimaryKeyRelatedField(queryset=Project.objects.all(), many=True)
539+
540+
class Meta:
541+
model = TaxaList
542+
fields = ["id", "name", "description", "taxa", "projects"]
543+
544+
def get_taxa(self, obj):
545+
"""
546+
Return the URL of the taxa endpoint, filtered by this TaxaList.
547+
"""
548+
return reverse_with_params(
549+
"taxon-list",
550+
request=self.context.get("request"),
551+
params={"taxa_list_id": obj.pk},
552+
)
553+
554+
535555
class CaptureTaxonSerializer(DefaultSerializer):
536556
parent = TaxonNoParentNestedSerializer(read_only=True)
537557
parents = TaxonParentSerializer(many=True, read_only=True)

ami/main/api/views.py

Lines changed: 85 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from django.contrib.postgres.search import TrigramSimilarity
66
from django.core import exceptions
77
from django.db import models
8-
from django.db.models import Prefetch
8+
from django.db.models import Prefetch, Q
99
from django.db.models.functions import Coalesce
1010
from django.db.models.query import QuerySet
1111
from django.forms import BooleanField, CharField, IntegerField
@@ -59,6 +59,7 @@
5959
SourceImage,
6060
SourceImageCollection,
6161
SourceImageUpload,
62+
TaxaList,
6263
Taxon,
6364
User,
6465
update_detection_counts,
@@ -89,6 +90,7 @@
8990
SourceImageUploadSerializer,
9091
StorageSourceSerializer,
9192
StorageStatusSerializer,
93+
TaxaListSerializer,
9294
TaxonListSerializer,
9395
TaxonSearchResultSerializer,
9496
TaxonSerializer,
@@ -967,6 +969,38 @@ def filter_queryset(self, request, queryset, view):
967969
return queryset
968970

969971

972+
class OccurrenceTaxaListFilter(filters.BaseFilterBackend):
973+
"""
974+
Filters occurrences based on a TaxaList.
975+
976+
Queries for all occurrences where the determination taxon is either:
977+
- Directly in the requested TaxaList.
978+
- A descendant (child or deeper) of any taxon in the TaxaList, recursively.
979+
980+
"""
981+
982+
query_param = "taxa_list_id"
983+
984+
def filter_queryset(self, request, queryset, view):
985+
taxalist_id = IntegerField(required=False).clean(request.query_params.get(self.query_param))
986+
if taxalist_id:
987+
taxa_list = TaxaList.objects.filter(id=taxalist_id).first()
988+
if taxa_list:
989+
taxa = taxa_list.taxa.all() # Get taxalist taxon objects
990+
991+
# filter by the exact determination
992+
query_filter = Q(determination__in=taxa)
993+
994+
# filter by the taxon's children
995+
for taxon in taxa:
996+
query_filter |= Q(determination__parents_json__contains=[{"id": taxon.pk}])
997+
998+
queryset = queryset.filter(query_filter)
999+
return queryset
1000+
1001+
return queryset
1002+
1003+
9701004
class TaxonCollectionFilter(filters.BaseFilterBackend):
9711005
"""
9721006
Filter taxa by the collection their occurrences belong to.
@@ -999,6 +1033,7 @@ class OccurrenceViewSet(DefaultViewSet, ProjectMixin):
9991033
OccurrenceDateFilter,
10001034
OccurrenceVerified,
10011035
OccurrenceVerifiedByMeFilter,
1036+
OccurrenceTaxaListFilter,
10021037
]
10031038
filterset_fields = [
10041039
"event",
@@ -1030,9 +1065,9 @@ def get_serializer_class(self):
10301065
else:
10311066
return OccurrenceSerializer
10321067

1033-
def get_queryset(self) -> QuerySet:
1068+
def get_queryset(self) -> QuerySet["Occurrence"]:
10341069
project = self.get_active_project()
1035-
qs = super().get_queryset()
1070+
qs = super().get_queryset().valid() # type: ignore
10361071
if project:
10371072
qs = qs.filter(project=project)
10381073
qs = qs.select_related(
@@ -1046,10 +1081,7 @@ def get_queryset(self) -> QuerySet:
10461081
if self.action == "list":
10471082
qs = (
10481083
qs.all()
1049-
.exclude(detections=None)
1050-
.exclude(event=None)
10511084
.filter(determination_score__gte=get_active_classification_threshold(self.request))
1052-
.exclude(first_appearance_timestamp=None) # This must come after annotations
10531085
.order_by("-determination_score")
10541086
)
10551087

@@ -1067,14 +1099,45 @@ def list(self, request, *args, **kwargs):
10671099
return super().list(request, *args, **kwargs)
10681100

10691101

1102+
class TaxonTaxaListFilter(filters.BaseFilterBackend):
1103+
"""
1104+
Filters taxa based on a TaxaList, similar to `OccurrenceTaxaListFilter`.
1105+
1106+
Queries for all taxa that are either:
1107+
- Directly in the requested TaxaList.
1108+
- A descendant (child or deeper) of any taxon in the TaxaList, recursively.
1109+
"""
1110+
1111+
query_param = "taxa_list_id"
1112+
1113+
def filter_queryset(self, request, queryset, view):
1114+
taxalist_id = IntegerField(required=False).clean(request.query_params.get(self.query_param))
1115+
if taxalist_id:
1116+
taxa_list = TaxaList.objects.filter(id=taxalist_id).first()
1117+
if taxa_list:
1118+
taxa = taxa_list.taxa.all() # Get taxa in the TaxaList
1119+
query_filter = Q(id__in=taxa)
1120+
for taxon in taxa:
1121+
query_filter |= Q(parents_json__contains=[{"id": taxon.pk}])
1122+
1123+
queryset = queryset.filter(query_filter)
1124+
return queryset
1125+
1126+
return queryset
1127+
1128+
10701129
class TaxonViewSet(DefaultViewSet, ProjectMixin):
10711130
"""
10721131
API endpoint that allows taxa to be viewed or edited.
10731132
"""
10741133

10751134
queryset = Taxon.objects.all().defer("notes")
10761135
serializer_class = TaxonSerializer
1077-
filter_backends = DefaultViewSetMixin.filter_backends + [CustomTaxonFilter, TaxonCollectionFilter]
1136+
filter_backends = DefaultViewSetMixin.filter_backends + [
1137+
CustomTaxonFilter,
1138+
TaxonCollectionFilter,
1139+
TaxonTaxaListFilter,
1140+
]
10781141
filterset_fields = [
10791142
"name",
10801143
"rank",
@@ -1286,6 +1349,19 @@ def list(self, request, *args, **kwargs):
12861349
return super().list(request, *args, **kwargs)
12871350

12881351

1352+
class TaxaListViewSet(viewsets.ModelViewSet, ProjectMixin):
1353+
queryset = TaxaList.objects.all()
1354+
1355+
def get_queryset(self):
1356+
qs = super().get_queryset()
1357+
project = self.get_active_project()
1358+
if project:
1359+
return qs.filter(projects=project)
1360+
return qs
1361+
1362+
serializer_class = TaxaListSerializer
1363+
1364+
12891365
class ClassificationViewSet(DefaultViewSet, ProjectMixin):
12901366
"""
12911367
API endpoint for viewing and adding classification results from a model.
@@ -1343,11 +1419,7 @@ def get(self, request):
13431419
"events_count": Event.objects.filter(deployment__project=project, deployment__isnull=False).count(),
13441420
"captures_count": SourceImage.objects.filter(deployment__project=project).count(),
13451421
# "detections_count": Detection.objects.filter(occurrence__project=project).count(),
1346-
"occurrences_count": Occurrence.objects.filter(
1347-
project=project,
1348-
# determination_score__gte=confidence_threshold,
1349-
event__isnull=False,
1350-
).count(),
1422+
"occurrences_count": Occurrence.objects.valid().filter(project=project).count(), # type: ignore
13511423
"taxa_count": Occurrence.objects.all().unique_taxa(project=project).count(), # type: ignore
13521424
}
13531425
else:
@@ -1357,10 +1429,7 @@ def get(self, request):
13571429
"events_count": Event.objects.filter(deployment__isnull=False).count(),
13581430
"captures_count": SourceImage.objects.count(),
13591431
# "detections_count": Detection.objects.count(),
1360-
"occurrences_count": Occurrence.objects.filter(
1361-
# determination_score__gte=confidence_threshold,
1362-
event__isnull=False
1363-
).count(),
1432+
"occurrences_count": Occurrence.objects.valid().count(), # type: ignore
13641433
"taxa_count": Occurrence.objects.all().unique_taxa().count(), # type: ignore
13651434
"last_updated": timezone.now(),
13661435
}

ami/main/models.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2174,11 +2174,14 @@ def __str__(self) -> str:
21742174
return f"#{self.pk} from SourceImage #{self.source_image_id} with Algorithm #{self.detection_algorithm_id}"
21752175

21762176

2177-
class OccurrenceQuerySet(models.QuerySet):
2178-
def with_detections_count(self) -> models.QuerySet:
2177+
class OccurrenceQuerySet(models.QuerySet["Occurrence"]):
2178+
def valid(self):
2179+
return self.exclude(detections__isnull=True)
2180+
2181+
def with_detections_count(self):
21792182
return self.annotate(detections_count=models.Count("detections", distinct=True))
21802183

2181-
def with_timestamps(self) -> models.QuerySet:
2184+
def with_timestamps(self):
21822185
"""
21832186
These are timestamps used for filtering and ordering in the UI.
21842187
"""
@@ -2192,14 +2195,14 @@ def with_timestamps(self) -> models.QuerySet:
21922195
),
21932196
)
21942197

2195-
def with_identifications(self) -> models.QuerySet:
2198+
def with_identifications(self):
21962199
return self.prefetch_related(
21972200
"identifications",
21982201
"identifications__taxon",
21992202
"identifications__user",
22002203
)
22012204

2202-
def unique_taxa(self, project: Project | None = None) -> models.QuerySet:
2205+
def unique_taxa(self, project: Project | None = None):
22032206
qs = self
22042207
if project:
22052208
qs = self.filter(project=project)
@@ -2211,12 +2214,16 @@ def unique_taxa(self, project: Project | None = None) -> models.QuerySet:
22112214
return qs
22122215

22132216

2214-
class OccurrenceManager(models.Manager):
2215-
def get_queryset(self) -> OccurrenceQuerySet:
2216-
return OccurrenceQuerySet(self.model, using=self._db).select_related(
2217-
"determination",
2218-
"deployment",
2219-
"project",
2217+
class OccurrenceManager(models.Manager.from_queryset(OccurrenceQuerySet)):
2218+
def get_queryset(self):
2219+
return (
2220+
super()
2221+
.get_queryset()
2222+
.select_related(
2223+
"determination",
2224+
"deployment",
2225+
"project",
2226+
)
22202227
)
22212228

22222229

@@ -2236,7 +2243,7 @@ class Occurrence(BaseModel):
22362243
detections: models.QuerySet[Detection]
22372244
identifications: models.QuerySet[Identification]
22382245

2239-
objects: OccurrenceManager = OccurrenceManager()
2246+
objects = OccurrenceManager()
22402247

22412248
def __str__(self) -> str:
22422249
name = f"Occurrence #{self.pk}"

ami/ml/apps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
from django.utils.translation import gettext_lazy as _
33

44

5-
class UsersConfig(AppConfig):
5+
class MLConfig(AppConfig):
66
name = "ami.ml"
77
verbose_name = _("Machine Learning")
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Generated by Django 4.2.10 on 2025-03-19 16:27
2+
3+
import ami.ml.schemas
4+
from django.db import migrations
5+
import django_pydantic_field.fields
6+
7+
8+
class Migration(migrations.Migration):
9+
dependencies = [
10+
("ml", "0020_projectpipelineconfig_alter_pipeline_projects"),
11+
]
12+
13+
operations = [
14+
migrations.AddField(
15+
model_name="pipeline",
16+
name="default_config",
17+
field=django_pydantic_field.fields.PydanticSchemaField(
18+
config=None,
19+
default=dict,
20+
help_text="The default configuration for the pipeline. Used by both the job sending images to the pipeline and the processing service.",
21+
schema=ami.ml.schemas.PipelineRequestConfigParameters,
22+
),
23+
),
24+
]

0 commit comments

Comments
 (0)