|
4 | 4 |
|
5 | 5 | from django.core.files.uploadedfile import SimpleUploadedFile
|
6 | 6 | from django.db import connection, models
|
7 |
| -from django.test import TestCase |
| 7 | +from django.test import TestCase, override_settings |
8 | 8 | from guardian.shortcuts import get_perms
|
9 | 9 | from PIL import Image
|
10 | 10 | from rest_framework import status
|
|
29 | 29 | group_images_into_events,
|
30 | 30 | )
|
31 | 31 | from ami.ml.models.pipeline import Pipeline
|
| 32 | +from ami.ml.models.project_pipeline_config import ProjectPipelineConfig |
32 | 33 | from ami.tests.fixtures.main import create_captures, create_occurrences, create_taxa, setup_test_project
|
33 | 34 | from ami.tests.fixtures.storage import populate_bucket
|
34 | 35 | from ami.users.models import User
|
|
37 | 38 | logger = logging.getLogger(__name__)
|
38 | 39 |
|
39 | 40 |
|
| 41 | +class TestProjectSetup(TestCase): |
| 42 | + def test_project_creation(self): |
| 43 | + project = Project.objects.create(name="New Project with Defaults", create_defaults=True) |
| 44 | + self.assertIsInstance(project, Project) |
| 45 | + |
| 46 | + def test_default_related_models(self): |
| 47 | + """Test that the default related models are created correctly when a project is created.""" |
| 48 | + project = Project.objects.create(name="New Project with Defaults", create_defaults=True) |
| 49 | + |
| 50 | + # Check that the project has a default deployment |
| 51 | + self.assertGreaterEqual(project.deployments.count(), 1) |
| 52 | + deployment = project.deployments.first() |
| 53 | + self.assertIsInstance(deployment, Deployment) |
| 54 | + |
| 55 | + # Check that the deployment has a default site |
| 56 | + self.assertGreaterEqual(project.sites.count(), 1) |
| 57 | + site = project.sites.first() |
| 58 | + self.assertIsInstance(site, Site) |
| 59 | + |
| 60 | + # Check that the deployment has a default device |
| 61 | + self.assertGreaterEqual(project.devices.count(), 1) |
| 62 | + device = project.devices.first() |
| 63 | + self.assertIsInstance(device, Device) |
| 64 | + |
| 65 | + # Check that the project has a default source image collection |
| 66 | + self.assertGreaterEqual(project.sourceimage_collections.count(), 1) |
| 67 | + collection = project.sourceimage_collections.first() |
| 68 | + self.assertIsInstance(collection, SourceImageCollection) |
| 69 | + |
| 70 | + # Disable this test for now, as it requires a more complex setup |
| 71 | + def no_test_default_permissions(self): |
| 72 | + pass |
| 73 | + |
| 74 | + @override_settings( |
| 75 | + DEFAULT_PROCESSING_SERVICE_NAME="Default Processing Service", |
| 76 | + DEFAULT_PROCESSING_SERVICE_ENDPOINT="http://ml_backend:2009/", |
| 77 | + ) |
| 78 | + def test_processing_service_if_configured(self): |
| 79 | + """ |
| 80 | + Test that the default processing service is created if the environment variables are set. |
| 81 | + """ |
| 82 | + from ami.ml.models.processing_service import get_or_create_default_processing_service |
| 83 | + |
| 84 | + project = Project.objects.create(name="Test Project for Processing Service", create_defaults=False) |
| 85 | + |
| 86 | + service = get_or_create_default_processing_service(project=project, register_pipelines=False) |
| 87 | + self.assertIsNotNone(service, "Default processing service should be created if environment variables are set.") |
| 88 | + assert service is not None # For type checking |
| 89 | + self.assertIsNotNone(service.endpoint_url) |
| 90 | + self.assertIsNotNone(service.name) |
| 91 | + self.assertGreaterEqual(project.processing_services.count(), 1) |
| 92 | + |
| 93 | + @override_settings( |
| 94 | + DEFAULT_PROCESSING_SERVICE_NAME=None, |
| 95 | + DEFAULT_PROCESSING_SERVICE_ENDPOINT=None, |
| 96 | + ) |
| 97 | + def test_processing_service_if_not_configured(self): |
| 98 | + """ |
| 99 | + Test that the default processing service is not created if the environment variables are not set. |
| 100 | + """ |
| 101 | + from ami.ml.models.processing_service import get_or_create_default_processing_service |
| 102 | + |
| 103 | + project = Project.objects.create(name="Test Project for Processing Service", create_defaults=False) |
| 104 | + |
| 105 | + service = get_or_create_default_processing_service(project=project) |
| 106 | + self.assertIsNone( |
| 107 | + service, "Default processing service should not be created if environment variables are not set." |
| 108 | + ) |
| 109 | + |
| 110 | + @override_settings( |
| 111 | + DEFAULT_PROCESSING_SERVICE_NAME="Default Processing Service", |
| 112 | + DEFAULT_PROCESSING_SERVICE_ENDPOINT="http://ml_backend:2000/", |
| 113 | + DEFAULT_PIPELINES_ENABLED=[], # All pipelines DISABLED by default |
| 114 | + ) |
| 115 | + def test_processing_service_with_disabled_pipelines(self): |
| 116 | + """ |
| 117 | + Test that the default processing service is created with all pipelines disabled |
| 118 | + if DEFAULT_PIPELINES_ENABLED is any empty list. |
| 119 | + """ |
| 120 | + project = Project.objects.create(name="Test Project for Processing Service", create_defaults=True) |
| 121 | + processing_service = project.processing_services.first() |
| 122 | + assert processing_service is not None |
| 123 | + # There should be at least two pipelines created by default |
| 124 | + self.assertGreaterEqual(processing_service.pipelines.count(), 2) |
| 125 | + # All pipelines should be disabled by default |
| 126 | + project_pipeline_configs = ProjectPipelineConfig.objects.filter(project=project) |
| 127 | + for config in project_pipeline_configs: |
| 128 | + self.assertFalse( |
| 129 | + config.enabled, |
| 130 | + f"Pipeline {config.pipeline.name} should be disabled for project {project.name}.", |
| 131 | + ) |
| 132 | + |
| 133 | + @override_settings( |
| 134 | + DEFAULT_PROCESSING_SERVICE_NAME="Default Processing Service", |
| 135 | + DEFAULT_PROCESSING_SERVICE_ENDPOINT="http://ml_backend:2000/", |
| 136 | + DEFAULT_PIPELINES_ENABLED=None, # All pipelines ENABLED by default |
| 137 | + ) |
| 138 | + def test_processing_service_with_enabled_pipelines(self): |
| 139 | + """ |
| 140 | + Test that the default processing service is created with all pipelines enabled |
| 141 | + if the DEFAULT_PIPELINES_ENABLED setting is None (or missing). |
| 142 | + """ |
| 143 | + project = Project.objects.create(name="Test Project for Processing Service", create_defaults=True) |
| 144 | + processing_service = project.processing_services.first() |
| 145 | + assert processing_service is not None |
| 146 | + # There should be at least two pipelines created by default |
| 147 | + self.assertGreaterEqual(processing_service.pipelines.count(), 2) |
| 148 | + # All pipelines should be enabled by default |
| 149 | + project_pipeline_configs = ProjectPipelineConfig.objects.filter(project=project) |
| 150 | + for config in project_pipeline_configs: |
| 151 | + self.assertTrue( |
| 152 | + config.enabled, |
| 153 | + f"Pipeline {config.pipeline.name} should be enabled for project {project.name}.", |
| 154 | + ) |
| 155 | + |
| 156 | + @override_settings( |
| 157 | + DEFAULT_PROCESSING_SERVICE_NAME="Default Processing Service", |
| 158 | + DEFAULT_PROCESSING_SERVICE_ENDPOINT="http://ml_backend:2000/", # should have at least two pipelines |
| 159 | + DEFAULT_PIPELINES_ENABLED=["constant"], |
| 160 | + ) |
| 161 | + def test_existing_processing_service_new_project(self): |
| 162 | + """ |
| 163 | + Create a new project, enable all pipelines. |
| 164 | + Create a 2nd project, ensure that the same processing service is used and only the enabled pipelines are |
| 165 | + registered. |
| 166 | + """ |
| 167 | + enabled_pipelines = ["constant"] |
| 168 | + |
| 169 | + project_one = Project.objects.create(name="Test Project One", create_defaults=True) |
| 170 | + |
| 171 | + # Enable all pipelines for the first project |
| 172 | + ProjectPipelineConfig.objects.filter(project=project_one).update(enabled=True) |
| 173 | + |
| 174 | + project_two = Project.objects.create(name="Test Project Two", create_defaults=True) |
| 175 | + |
| 176 | + project_one_processing_service = project_one.processing_services.first() |
| 177 | + project_two_processing_service = project_two.processing_services.first() |
| 178 | + |
| 179 | + assert project_one_processing_service is not None |
| 180 | + assert project_two_processing_service is not None |
| 181 | + |
| 182 | + # Ensure only the same processing service instance is used (and they are not None) |
| 183 | + self.assertEqual( |
| 184 | + project_one_processing_service, |
| 185 | + project_two_processing_service, |
| 186 | + "Both projects should use the same processing service instance.", |
| 187 | + ) |
| 188 | + |
| 189 | + # Ensure that only the enabled pipelines are enabled for the second project |
| 190 | + project_two_pipeline_configs = ProjectPipelineConfig.objects.filter(project=project_two) |
| 191 | + self.assertGreaterEqual(project_two_pipeline_configs.count(), 2, "Project should have at least two pipelines.") |
| 192 | + for config in project_two_pipeline_configs: |
| 193 | + if config.pipeline.slug in enabled_pipelines: |
| 194 | + self.assertTrue( |
| 195 | + config.enabled, |
| 196 | + f"Pipeline {config.pipeline.name} should be enabled for project {project_two.name}.", |
| 197 | + ) |
| 198 | + else: |
| 199 | + self.assertFalse( |
| 200 | + config.enabled, |
| 201 | + f"Pipeline {config.pipeline.name} should not be enabled for project {project_two.name}.", |
| 202 | + ) |
| 203 | + |
| 204 | + |
40 | 205 | class TestImageGrouping(TestCase):
|
41 | 206 | def setUp(self) -> None:
|
42 | 207 | print(f"Currently active database: {connection.settings_dict}")
|
|
0 commit comments