Skip to content

Commit cbfd449

Browse files
authored
Merge branch 'main' into feat/quickstart-auto-process
2 parents 4c33807 + 41eb1e6 commit cbfd449

File tree

8 files changed

+349
-35
lines changed

8 files changed

+349
-35
lines changed

.envs/.local/.django

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,8 @@ MINIO_DEFAULT_BUCKET=ami
3737
MINIO_STORAGE_USE_HTTPS=False
3838
MINIO_TEST_BUCKET=ami-test
3939
MINIO_BROWSER_REDIRECT_URL=http://minio:9001
40+
41+
# Default processing service (local)
42+
DEFAULT_PROCESSING_SERVICE_NAME=Local Processing Service
43+
DEFAULT_PROCESSING_SERVICE_ENDPOINT=http://ml_backend:2000
44+
# DEFAULT_PIPELINES_ENABLED=random,constant # When set to None, all pipelines will be enabled.

.envs/.production/.django-example

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,8 @@ DJANGO_ACCOUNT_ALLOW_REGISTRATION=True
5353
# Gunicorn
5454
# ------------------------------------------------------------------------------
5555
WEB_CONCURRENCY=4
56+
57+
# Default processing service
58+
DEFAULT_PROCESSING_SERVICE_NAME="AMI Data Companion"
59+
DEFAULT_PROCESSING_SERVICE_ENDPOINT=https://ml.antenna.insectai.org/
60+
DEFAULT_PIPELINES_ENABLED=global_moths_2024,quebec_vermont_moths_2023,panama_moths_2023,uk_denmark_moths_2023

ami/main/models.py

Lines changed: 77 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from django.contrib.postgres.fields import ArrayField
1616
from django.core.exceptions import ValidationError
1717
from django.core.files.storage import default_storage
18-
from django.db import IntegrityError, models
18+
from django.db import IntegrityError, models, transaction
1919
from django.db.models import Q
2020
from django.db.models.fields.files import ImageFieldFile
2121
from django.db.models.signals import pre_delete
@@ -34,6 +34,7 @@
3434

3535
if typing.TYPE_CHECKING:
3636
from ami.jobs.models import Job
37+
from ami.ml.models import ProcessingService
3738

3839
logger = logging.getLogger(__name__)
3940

@@ -89,20 +90,57 @@ def get_media_url(path: str) -> str:
8990
as_choices = lambda x: [(i, i) for i in x] # noqa: E731
9091

9192

92-
def create_default_device(project: "Project") -> "Device":
93+
def get_or_create_default_device(project: "Project") -> "Device":
9394
"""Create a default device for a project."""
9495
device, _created = Device.objects.get_or_create(name="Default device", project=project)
9596
logger.info(f"Created default device for project {project}")
9697
return device
9798

9899

99-
def create_default_research_site(project: "Project") -> "Site":
100+
def get_or_create_default_research_site(project: "Project") -> "Site":
100101
"""Create a default research site for a project."""
101102
site, _created = Site.objects.get_or_create(name="Default site", project=project)
102103
logger.info(f"Created default research site for project {project}")
103104
return site
104105

105106

107+
def get_or_create_default_deployment(
108+
project: "Project", site: "Site | None" = None, device: "Device | None" = None
109+
) -> "Deployment":
110+
"""Create a default deployment for a project."""
111+
deployment, _created = Deployment.objects.get_or_create(
112+
name="Default Station",
113+
project=project,
114+
research_site=site,
115+
device=device,
116+
)
117+
logger.info(f"Created default deployment for project {project}")
118+
return deployment
119+
120+
121+
def get_or_create_default_collection(project: "Project") -> "SourceImageCollection":
122+
"""Create a default collection for a project for all images, updated dynamically."""
123+
collection, _created = SourceImageCollection.objects.get_or_create(
124+
name="All Images",
125+
project=project,
126+
)
127+
logger.info(f"Created default collection for project {project}")
128+
return collection
129+
130+
131+
def get_or_create_default_project(user: User) -> "Project":
132+
"""
133+
Create a default project for a user.
134+
135+
Default related objects like devices and research sites will be created
136+
when the project is saved for the first time.
137+
If the project already exists, it will be returned without modification.
138+
"""
139+
project, _created = Project.objects.get_or_create(name="Scratch Project", owner=user, create_defaults=True)
140+
logger.info(f"Created default project for user {user}")
141+
return project
142+
143+
106144
class ProjectQuerySet(models.QuerySet):
107145
def filter_by_user(self, user: User):
108146
"""
@@ -115,6 +153,39 @@ class ProjectManager(models.Manager):
115153
def get_queryset(self) -> ProjectQuerySet:
116154
return ProjectQuerySet(self.model, using=self._db)
117155

156+
def create(self, create_defaults: bool = True, **kwargs) -> "Project":
157+
"""
158+
Create a new Project and related models with defaults.
159+
160+
Args:
161+
create_defaults: Whether to create default related models
162+
**kwargs: Model field values
163+
164+
Returns:
165+
Created Project instance
166+
"""
167+
with transaction.atomic():
168+
project_instance = super().create(**kwargs)
169+
logger.info(f"Created project: {project_instance.name}")
170+
171+
if create_defaults:
172+
self.create_related_defaults(project_instance)
173+
174+
return project_instance
175+
176+
def create_related_defaults(self, project: "Project"):
177+
"""Create default device, and other related models for this project if they don't exist."""
178+
device = get_or_create_default_device(project=project)
179+
site = get_or_create_default_research_site(project=project)
180+
if not project.deployments.exists():
181+
get_or_create_default_deployment(project=project, site=site, device=device)
182+
if not project.sourceimage_collections.exists():
183+
get_or_create_default_collection(project=project)
184+
if not project.processing_services.exists():
185+
from ami.ml.models.processing_service import get_or_create_default_processing_service
186+
187+
get_or_create_default_processing_service(project=project)
188+
118189

119190
@final
120191
class Project(BaseModel):
@@ -140,6 +211,8 @@ class Project(BaseModel):
140211
devices: models.QuerySet["Device"]
141212
sites: models.QuerySet["Site"]
142213
jobs: models.QuerySet["Job"]
214+
sourceimage_collections: models.QuerySet["SourceImageCollection"]
215+
processing_services: models.QuerySet["ProcessingService"]
143216

144217
objects = ProjectManager()
145218

@@ -177,21 +250,10 @@ def summary_data(self):
177250

178251
return plots
179252

180-
def create_related_defaults(self):
181-
"""Create default device, and other related models for this project if they don't exist."""
182-
if not self.devices.exists():
183-
create_default_device(project=self)
184-
if not self.sites.exists():
185-
create_default_research_site(project=self)
186-
187253
def save(self, *args, **kwargs):
188-
new_project = bool(self._state.adding)
189254
super().save(*args, **kwargs)
190255
# Add owner to members
191256
self.ensure_owner_membership()
192-
if new_project:
193-
logger.info(f"Created new project {self}")
194-
self.create_related_defaults()
195257

196258
class Permissions:
197259
"""CRUD Permission names follow the convention: `create_<model>`, `update_<model>`,
@@ -1192,6 +1254,7 @@ class S3StorageSource(BaseModel):
11921254
# last_check_duration = models.DurationField(null=True, blank=True)
11931255
# use_signed_urls = models.BooleanField(default=False)
11941256
project = models.ForeignKey(Project, on_delete=models.SET_NULL, null=True, related_name="storage_sources")
1257+
# @TODO allow multiple projects to share the same S3StorageSource
11951258

11961259
deployments: models.QuerySet["Deployment"]
11971260

ami/main/tests.py

Lines changed: 166 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from django.core.files.uploadedfile import SimpleUploadedFile
66
from django.db import connection, models
7-
from django.test import TestCase
7+
from django.test import TestCase, override_settings
88
from guardian.shortcuts import get_perms
99
from PIL import Image
1010
from rest_framework import status
@@ -29,6 +29,7 @@
2929
group_images_into_events,
3030
)
3131
from ami.ml.models.pipeline import Pipeline
32+
from ami.ml.models.project_pipeline_config import ProjectPipelineConfig
3233
from ami.tests.fixtures.main import create_captures, create_occurrences, create_taxa, setup_test_project
3334
from ami.tests.fixtures.storage import populate_bucket
3435
from ami.users.models import User
@@ -37,6 +38,170 @@
3738
logger = logging.getLogger(__name__)
3839

3940

41+
class TestProjectSetup(TestCase):
42+
def test_project_creation(self):
43+
project = Project.objects.create(name="New Project with Defaults", create_defaults=True)
44+
self.assertIsInstance(project, Project)
45+
46+
def test_default_related_models(self):
47+
"""Test that the default related models are created correctly when a project is created."""
48+
project = Project.objects.create(name="New Project with Defaults", create_defaults=True)
49+
50+
# Check that the project has a default deployment
51+
self.assertGreaterEqual(project.deployments.count(), 1)
52+
deployment = project.deployments.first()
53+
self.assertIsInstance(deployment, Deployment)
54+
55+
# Check that the deployment has a default site
56+
self.assertGreaterEqual(project.sites.count(), 1)
57+
site = project.sites.first()
58+
self.assertIsInstance(site, Site)
59+
60+
# Check that the deployment has a default device
61+
self.assertGreaterEqual(project.devices.count(), 1)
62+
device = project.devices.first()
63+
self.assertIsInstance(device, Device)
64+
65+
# Check that the project has a default source image collection
66+
self.assertGreaterEqual(project.sourceimage_collections.count(), 1)
67+
collection = project.sourceimage_collections.first()
68+
self.assertIsInstance(collection, SourceImageCollection)
69+
70+
# Disable this test for now, as it requires a more complex setup
71+
def no_test_default_permissions(self):
72+
pass
73+
74+
@override_settings(
75+
DEFAULT_PROCESSING_SERVICE_NAME="Default Processing Service",
76+
DEFAULT_PROCESSING_SERVICE_ENDPOINT="http://ml_backend:2009/",
77+
)
78+
def test_processing_service_if_configured(self):
79+
"""
80+
Test that the default processing service is created if the environment variables are set.
81+
"""
82+
from ami.ml.models.processing_service import get_or_create_default_processing_service
83+
84+
project = Project.objects.create(name="Test Project for Processing Service", create_defaults=False)
85+
86+
service = get_or_create_default_processing_service(project=project, register_pipelines=False)
87+
self.assertIsNotNone(service, "Default processing service should be created if environment variables are set.")
88+
assert service is not None # For type checking
89+
self.assertIsNotNone(service.endpoint_url)
90+
self.assertIsNotNone(service.name)
91+
self.assertGreaterEqual(project.processing_services.count(), 1)
92+
93+
@override_settings(
94+
DEFAULT_PROCESSING_SERVICE_NAME=None,
95+
DEFAULT_PROCESSING_SERVICE_ENDPOINT=None,
96+
)
97+
def test_processing_service_if_not_configured(self):
98+
"""
99+
Test that the default processing service is not created if the environment variables are not set.
100+
"""
101+
from ami.ml.models.processing_service import get_or_create_default_processing_service
102+
103+
project = Project.objects.create(name="Test Project for Processing Service", create_defaults=False)
104+
105+
service = get_or_create_default_processing_service(project=project)
106+
self.assertIsNone(
107+
service, "Default processing service should not be created if environment variables are not set."
108+
)
109+
110+
@override_settings(
111+
DEFAULT_PROCESSING_SERVICE_NAME="Default Processing Service",
112+
DEFAULT_PROCESSING_SERVICE_ENDPOINT="http://ml_backend:2000/",
113+
DEFAULT_PIPELINES_ENABLED=[], # All pipelines DISABLED by default
114+
)
115+
def test_processing_service_with_disabled_pipelines(self):
116+
"""
117+
Test that the default processing service is created with all pipelines disabled
118+
if DEFAULT_PIPELINES_ENABLED is any empty list.
119+
"""
120+
project = Project.objects.create(name="Test Project for Processing Service", create_defaults=True)
121+
processing_service = project.processing_services.first()
122+
assert processing_service is not None
123+
# There should be at least two pipelines created by default
124+
self.assertGreaterEqual(processing_service.pipelines.count(), 2)
125+
# All pipelines should be disabled by default
126+
project_pipeline_configs = ProjectPipelineConfig.objects.filter(project=project)
127+
for config in project_pipeline_configs:
128+
self.assertFalse(
129+
config.enabled,
130+
f"Pipeline {config.pipeline.name} should be disabled for project {project.name}.",
131+
)
132+
133+
@override_settings(
134+
DEFAULT_PROCESSING_SERVICE_NAME="Default Processing Service",
135+
DEFAULT_PROCESSING_SERVICE_ENDPOINT="http://ml_backend:2000/",
136+
DEFAULT_PIPELINES_ENABLED=None, # All pipelines ENABLED by default
137+
)
138+
def test_processing_service_with_enabled_pipelines(self):
139+
"""
140+
Test that the default processing service is created with all pipelines enabled
141+
if the DEFAULT_PIPELINES_ENABLED setting is None (or missing).
142+
"""
143+
project = Project.objects.create(name="Test Project for Processing Service", create_defaults=True)
144+
processing_service = project.processing_services.first()
145+
assert processing_service is not None
146+
# There should be at least two pipelines created by default
147+
self.assertGreaterEqual(processing_service.pipelines.count(), 2)
148+
# All pipelines should be enabled by default
149+
project_pipeline_configs = ProjectPipelineConfig.objects.filter(project=project)
150+
for config in project_pipeline_configs:
151+
self.assertTrue(
152+
config.enabled,
153+
f"Pipeline {config.pipeline.name} should be enabled for project {project.name}.",
154+
)
155+
156+
@override_settings(
157+
DEFAULT_PROCESSING_SERVICE_NAME="Default Processing Service",
158+
DEFAULT_PROCESSING_SERVICE_ENDPOINT="http://ml_backend:2000/", # should have at least two pipelines
159+
DEFAULT_PIPELINES_ENABLED=["constant"],
160+
)
161+
def test_existing_processing_service_new_project(self):
162+
"""
163+
Create a new project, enable all pipelines.
164+
Create a 2nd project, ensure that the same processing service is used and only the enabled pipelines are
165+
registered.
166+
"""
167+
enabled_pipelines = ["constant"]
168+
169+
project_one = Project.objects.create(name="Test Project One", create_defaults=True)
170+
171+
# Enable all pipelines for the first project
172+
ProjectPipelineConfig.objects.filter(project=project_one).update(enabled=True)
173+
174+
project_two = Project.objects.create(name="Test Project Two", create_defaults=True)
175+
176+
project_one_processing_service = project_one.processing_services.first()
177+
project_two_processing_service = project_two.processing_services.first()
178+
179+
assert project_one_processing_service is not None
180+
assert project_two_processing_service is not None
181+
182+
# Ensure only the same processing service instance is used (and they are not None)
183+
self.assertEqual(
184+
project_one_processing_service,
185+
project_two_processing_service,
186+
"Both projects should use the same processing service instance.",
187+
)
188+
189+
# Ensure that only the enabled pipelines are enabled for the second project
190+
project_two_pipeline_configs = ProjectPipelineConfig.objects.filter(project=project_two)
191+
self.assertGreaterEqual(project_two_pipeline_configs.count(), 2, "Project should have at least two pipelines.")
192+
for config in project_two_pipeline_configs:
193+
if config.pipeline.slug in enabled_pipelines:
194+
self.assertTrue(
195+
config.enabled,
196+
f"Pipeline {config.pipeline.name} should be enabled for project {project_two.name}.",
197+
)
198+
else:
199+
self.assertFalse(
200+
config.enabled,
201+
f"Pipeline {config.pipeline.name} should not be enabled for project {project_two.name}.",
202+
)
203+
204+
40205
class TestImageGrouping(TestCase):
41206
def setUp(self) -> None:
42207
print(f"Currently active database: {connection.settings_dict}")

0 commit comments

Comments
 (0)