Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions ami/main/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,6 +1095,20 @@ class SourceImageCollectionCommonKwargsSerializer(serializers.Serializer):
allow_empty=True,
)

event_ids = serializers.ListField(
child=serializers.IntegerField(),
required=False,
allow_null=True,
allow_empty=True,
)

research_site_ids = serializers.ListField(
child=serializers.IntegerField(),
required=False,
allow_null=True,
allow_empty=True,
)

# Kwargs for other sampling methods, this is not complete
# see the SourceImageCollection model for all available kwargs.
size = serializers.IntegerField(required=False, allow_null=True)
Expand All @@ -1110,8 +1124,6 @@ def to_representation(self, instance):


class SourceImageCollectionSerializer(DefaultSerializer):
# @TODO can sampling kwargs be a nested serializer instead??

source_images = serializers.SerializerMethodField()
kwargs = SourceImageCollectionCommonKwargsSerializer(required=False, partial=True)
jobs = JobStatusSerializer(many=True, read_only=True)
Expand Down
33 changes: 33 additions & 0 deletions ami/main/migrations/0063_alter_sourceimagecollection_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Generated by Django 4.2.10 on 2025-07-24 21:26

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("main", "0062_project_feature_flags"),
]

operations = [
migrations.AlterField(
model_name="sourceimagecollection",
name="method",
field=models.CharField(
choices=[
("full", "full"),
("random", "random"),
("stratified_random", "stratified_random"),
("interval", "interval"),
("manual", "manual"),
("starred", "starred"),
("random_from_each_event", "random_from_each_event"),
("last_and_random_from_each_event", "last_and_random_from_each_event"),
("greatest_file_size_from_each_event", "greatest_file_size_from_each_event"),
("detections_only", "detections_only"),
("common_combined", "common_combined"),
],
default="full",
max_length=255,
),
),
]
39 changes: 29 additions & 10 deletions ami/main/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def get_or_create_default_collection(project: "Project") -> "SourceImageCollecti
collection, _created = SourceImageCollection.objects.get_or_create(
name="All Images",
project=project,
method="full",
# @TODO make this a dynamic collection that updates automatically
)
logger.info(f"Created default collection for project {project}")
return collection
Expand Down Expand Up @@ -3141,7 +3143,7 @@ def html(self) -> str:


_SOURCE_IMAGE_SAMPLING_METHODS = [
"common_combined", # Deprecated
"full",
"random",
"stratified_random",
"interval",
Expand All @@ -3151,7 +3153,7 @@ def html(self) -> str:
"last_and_random_from_each_event",
"greatest_file_size_from_each_event",
"detections_only",
"full",
"common_combined", # Deprecated
]


Expand Down Expand Up @@ -3234,7 +3236,7 @@ class SourceImageCollection(BaseModel):
method = models.CharField(
max_length=255,
choices=as_choices(_SOURCE_IMAGE_SAMPLING_METHODS),
default="common_combined",
default="full",
)
# @TODO this should be a JSON field with a schema, use a pydantic model
kwargs = models.JSONField(
Expand Down Expand Up @@ -3279,13 +3281,8 @@ def taxa_count(self) -> int | None:

def get_queryset(
self,
hour_start: int | None = None,
hour_end: int | None = None,
month_start: int | None = None,
month_end: int | None = None,
date_start: str | None = None,
date_end: str | None = None,
deployment_ids: list[int] | None = None,
*args,
**kwargs,
):
return SourceImage.objects.filter(project=self.project)

Expand Down Expand Up @@ -3323,9 +3320,15 @@ def _filter_sample(
date_start: str | None = None,
date_end: str | None = None,
deployment_ids: list[int] | None = None,
research_site_ids: list[int] | None = None,
event_ids: list[int] | None = None,
):
if deployment_ids is not None:
qs = qs.filter(deployment__in=deployment_ids)
if research_site_ids is not None:
qs = qs.filter(deployment__research_site__in=research_site_ids)
if event_ids is not None:
qs = qs.filter(event__in=event_ids)
if date_start is not None:
qs = qs.filter(timestamp__date__gte=DateStringField.to_date(date_start))
if date_end is not None:
Expand Down Expand Up @@ -3360,6 +3363,8 @@ def sample_random(
date_start: str | None = None,
date_end: str | None = None,
deployment_ids: list[int] | None = None,
research_site_ids: list[int] | None = None,
event_ids: list[int] | None = None,
):
"""Create a random sample of source images"""

Expand All @@ -3373,6 +3378,8 @@ def sample_random(
date_start=date_start,
date_end=date_end,
deployment_ids=deployment_ids,
research_site_ids=research_site_ids,
event_ids=event_ids,
)
return qs.order_by("?")[:size]

Expand All @@ -3395,6 +3402,8 @@ def sample_common_combined(
date_start: str | None = None,
date_end: str | None = None,
deployment_ids: list[int] | None = None,
research_site_ids: list[int] | None = None,
event_ids: list[int] | None = None,
) -> models.QuerySet | typing.Generator[SourceImage, None, None]:
qs = self.get_queryset()
qs = self._filter_sample(
Expand All @@ -3406,6 +3415,8 @@ def sample_common_combined(
date_start=date_start,
date_end=date_end,
deployment_ids=deployment_ids,
research_site_ids=research_site_ids,
event_ids=event_ids,
)

if minute_interval is not None:
Expand Down Expand Up @@ -3434,6 +3445,8 @@ def sample_interval(
date_start: str | None = None,
date_end: str | None = None,
deployment_ids: list[int] | None = None,
research_site_ids: list[int] | None = None,
event_ids: list[int] | None = None,
):
"""Create a sample of source images based on a time interval"""

Expand All @@ -3447,6 +3460,8 @@ def sample_interval(
date_start=date_start,
date_end=date_end,
deployment_ids=deployment_ids,
research_site_ids=research_site_ids,
event_ids=event_ids,
)
if deployment_id:
qs = qs.filter(deployment=deployment_id)
Expand Down Expand Up @@ -3516,6 +3531,8 @@ def sample_full(
date_start: str | None = None,
date_end: str | None = None,
deployment_ids: list[int] | None = None,
research_site_ids: list[int] | None = None,
event_ids: list[int] | None = None,
):
"""Sample all source images"""

Expand All @@ -3529,6 +3546,8 @@ def sample_full(
date_start=date_start,
date_end=date_end,
deployment_ids=deployment_ids,
research_site_ids=research_site_ids,
event_ids=event_ids,
)
return qs.all().distinct()

Expand Down
3 changes: 3 additions & 0 deletions ui/src/components/form/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ export interface FieldConfig {
max?: number
validate?: (value: any) => string | undefined
}
// Processor functions for field value transformation
toApiValue?: (formValue: any) => any // Form → API
toFormValue?: (apiValue: any) => any // API → Form
}

export type FormConfig = {
Expand Down
2 changes: 1 addition & 1 deletion ui/src/pages/project/collections/constants.tsx
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
// Only some sampling methods are editable from the UI
export const SERVER_SAMPLING_METHODS = ['interval', 'random', 'full']
export const SERVER_SAMPLING_METHODS = ['full', 'interval', 'random']
101 changes: 64 additions & 37 deletions ui/src/pages/project/entities/details-form/collection-details-form.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ import { XIcon } from 'lucide-react'
import { Button, Select } from 'nova-ui-kit'
import { SERVER_SAMPLING_METHODS } from 'pages/project/collections/constants'
import { useForm } from 'react-hook-form'
import {
formatIntegerList,
parseIntegerList,
validateInteger,
validateIntegerList,
} from 'utils/fieldProcessors'
import { STRING, translate } from 'utils/language'
import { snakeCaseToSentenceCase } from 'utils/snakeCaseToSentenceCase'
import { useFormError } from 'utils/useFormError'
Expand All @@ -30,6 +36,9 @@ type CollectionFormValues = FormValues & {
max_num: number | undefined
minute_interval: number | undefined
size: number | undefined
deployment_ids: string | undefined
research_site_ids: string | undefined
event_ids: string | undefined
}
}

Expand Down Expand Up @@ -61,13 +70,7 @@ const config: FormConfig = {
rules: {
min: 0,
max: 24,
validate: (value) => {
if (value) {
if (!Number.isInteger(Number(value))) {
return translate(STRING.MESSAGE_VALUE_INVALID)
}
}
},
validate: validateInteger,
},
},
'kwargs.hour_end': {
Expand All @@ -76,56 +79,50 @@ const config: FormConfig = {
rules: {
min: 0,
max: 24,
validate: (value) => {
if (value) {
if (!Number.isInteger(Number(value))) {
return translate(STRING.MESSAGE_VALUE_INVALID)
}
}
},
validate: validateInteger,
},
},
'kwargs.max_num': {
label: 'Max number of captures',
rules: {
min: 0,
validate: (value) => {
if (value) {
if (!Number.isInteger(Number(value))) {
return translate(STRING.MESSAGE_VALUE_INVALID)
}
}
},
validate: validateInteger,
},
},
'kwargs.minute_interval': {
label: 'Minutes between captures',
rules: {
min: 0,
required: true,
validate: (value) => {
if (value) {
if (!Number.isInteger(Number(value))) {
return translate(STRING.MESSAGE_VALUE_INVALID)
}
}
},
validate: validateInteger,
},
},
'kwargs.size': {
label: 'Number of captures',
rules: {
min: 0,
required: true,
validate: (value) => {
if (value) {
if (!Number.isInteger(Number(value))) {
return translate(STRING.MESSAGE_VALUE_INVALID)
}
}
},
validate: validateInteger,
},
},
'kwargs.deployment_ids': {
label: 'Station IDs',
description: 'Enter comma-separated numbers (e.g., 1, 2, 3).',
rules: {
validate: validateIntegerList,
},
toApiValue: parseIntegerList,
toFormValue: formatIntegerList,
},
'kwargs.event_ids': {
label: 'Session IDs',
description: 'Enter comma-separated numbers (e.g., 1, 2, 3).',
rules: {
validate: validateIntegerList,
},
toApiValue: parseIntegerList,
toFormValue: formatIntegerList,
},
}

export const CollectionDetailsForm = ({
Expand All @@ -142,9 +139,17 @@ export const CollectionDetailsForm = ({
name: entity?.name ?? '',
description: entity?.description ?? '',
kwargs: {
...(collection?.kwargs ? collection.kwargs : {}),
minute_interval: 10,
size: 100,
...Object.fromEntries(
Object.entries(collection?.kwargs || {}).map(([key, value]) => {
const fieldConfig = config[`kwargs.${key}`]
const formValue = fieldConfig?.toFormValue
? fieldConfig.toFormValue(value)
: value
return [key, formValue]
})
),
},
method: collection?.method ?? SERVER_SAMPLING_METHODS[0],
},
Expand All @@ -170,7 +175,15 @@ export const CollectionDetailsForm = ({

return true
})
.map(([key, value]) => [key, value === '' ? null : value])
.map(([key, value]) => {
const fieldConfig = config[`kwargs.${key}`]
const processedValue = fieldConfig?.toApiValue
? fieldConfig.toApiValue(value)
: value === ''
? null
: value
return [key, processedValue]
})
)

onSubmit({
Expand Down Expand Up @@ -287,6 +300,20 @@ export const CollectionDetailsForm = ({
control={control}
/>
</FormRow>
<FormRow>
<FormField
name="kwargs.deployment_ids"
type="text"
config={config}
control={control}
/>
<FormField
name="kwargs.event_ids"
type="text"
config={config}
control={control}
/>
</FormRow>
</FormSection>
<FormSection>
<h3 className="body-large font-bold text-muted-foreground/50">
Expand Down
Loading