Commit 8ad8b57

Add test fixture
1 parent fa6579a commit 8ad8b57

5 files changed: +42 / -44 lines changed

ami/ml/models/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -1,6 +1,6 @@
-from .algorithm import Algorithm
-from .backend import Backend
-from .pipeline import Pipeline
+from ami.ml.models.algorithm import Algorithm
+from ami.ml.models.backend import Backend
+from ami.ml.models.pipeline import Pipeline
 
 __all__ = [
     "Algorithm",

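The relative imports are swapped for absolute module paths, and the package `__init__` keeps re-exporting the same names through `__all__`, so both import styles resolve to the same objects. A minimal sketch of that equivalence (illustrative aliases `A1`/`A2`, run inside the project's Django environment):

```python
# Both paths resolve to the same class object, because the package __init__
# imports Algorithm from ami.ml.models.algorithm and re-exports it.
from ami.ml.models import Algorithm as A1             # via the package re-export
from ami.ml.models.algorithm import Algorithm as A2   # direct module path used in this commit

assert A1 is A2
```
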
ami/ml/models/backend.py

Lines changed: 3 additions & 4 deletions
@@ -8,11 +8,10 @@
 
 from ami.base.models import BaseModel
 from ami.main.models import Project
+from ami.ml.models.algorithm import Algorithm
+from ami.ml.models.pipeline import Pipeline
 from ami.ml.schemas import BackendStatusResponse, PipelineRegistrationResponse
 
-from .algorithm import Algorithm
-from .pipeline import Pipeline
-
 logger = logging.getLogger(__name__)
 
 
@@ -22,7 +21,7 @@ class Backend(BaseModel):
 
     projects = models.ManyToManyField("main.Project", related_name="backends", blank=True)
     endpoint_url = models.CharField(max_length=1024, null=True, blank=True)
-    pipelines = models.ManyToManyField(Pipeline, related_name="backends", blank=True)
+    pipelines = models.ManyToManyField("ml.Pipeline", related_name="backends", blank=True)
 
     def __str__(self):
         return self.endpoint_url
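The `pipelines` field now names its target with the lazy string `"ml.Pipeline"` rather than the imported `Pipeline` class, matching the existing `"main.Project"` reference. A minimal standalone sketch of that pattern, assuming a Django app registered with label `ml` (not the project's actual file):

```python
from django.db import models


class Backend(models.Model):
    endpoint_url = models.CharField(max_length=1024, null=True, blank=True)
    # "app_label.ModelName" is resolved lazily through Django's app registry,
    # so this module does not need Pipeline imported at class-definition time
    # and a backend.py <-> pipeline.py circular import cannot occur.
    pipelines = models.ManyToManyField("ml.Pipeline", related_name="backends", blank=True)

    class Meta:
        app_label = "ml"
```
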

ami/ml/models/pipeline.py

Lines changed: 2 additions & 3 deletions
@@ -30,11 +30,10 @@
     TaxonRank,
     update_calculated_fields_for_events,
 )
+from ami.ml.models.algorithm import Algorithm
 from ami.ml.schemas import PipelineRequest, PipelineResponse, SourceImageRequest
 from ami.ml.tasks import celery_app, create_detection_images
 
-from .algorithm import Algorithm
-
 logger = logging.getLogger(__name__)
 
 
@@ -414,7 +413,7 @@ class Pipeline(BaseModel):
     version = models.IntegerField(default=1)
     version_name = models.CharField(max_length=255, blank=True)
     # @TODO the algorithms list be retrieved by querying the pipeline endpoint
-    algorithms = models.ManyToManyField(Algorithm, related_name="pipelines")
+    algorithms = models.ManyToManyField("ml.Algorithm", related_name="pipelines")
     stages: list[PipelineStage] = SchemaField(
         default=default_stages,
         help_text=(

ami/ml/tests.py

Lines changed: 8 additions & 3 deletions
@@ -13,19 +13,24 @@
     PipelineResponse,
     SourceImageResponse,
 )
-from ami.tests.fixtures.main import create_captures_from_files, create_ml_pipeline, setup_test_project
+from ami.tests.fixtures.main import create_captures_from_files, create_ml_backends, setup_test_project
 
 
 class TestPipelineWithMLBackend(TestCase):
     def setUp(self):
         self.project, self.deployment = setup_test_project()
         self.captures = create_captures_from_files(self.deployment, skip_existing=False)
         self.test_images = [image for image, frame in self.captures]
-        self.pipeline = create_ml_pipeline(self.project)
+        self.backend_instance = create_ml_backends(self.project)
+        self.backend = self.backend_instance
+        # @TODO: Create function to get most recent OK backend
+        self.pipeline = self.backend_instance.pipelines.all().filter(slug="constant").first()
+        self.backend_id = self.pipeline.backends.first().pk
+        # @TODO: Add error or info messages to the response if image already processed or no detections returned
 
     def test_run_pipeline(self):
         # Send images to ML backend to process and return detections
-        pipeline_response = self.pipeline.process_images(self.test_images)
+        pipeline_response = self.pipeline.process_images(self.test_images, backend_id=self.backend_id, job_id=None)
         assert pipeline_response.detections
 
 
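setUp picks the backend directly with `self.pipeline.backends.first().pk` and leaves a `@TODO` for selecting the most recent healthy backend. A hypothetical helper along those lines, assuming the base model provides a `created_at` timestamp (not confirmed by this diff):

```python
def get_backend_id_for_pipeline(pipeline):
    """Return the pk of one backend registered for this pipeline, or None.

    Hypothetical sketch for the @TODO above; ordering by "-created_at" assumes
    a timestamp field on the Backend model and may need adjusting.
    """
    backend = pipeline.backends.order_by("-created_at").first()
    return backend.pk if backend else None
```

With such a helper, the line `self.backend_id = self.pipeline.backends.first().pk` could become `self.backend_id = get_backend_id_for_pipeline(self.pipeline)`.
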

ami/tests/fixtures/main.py

Lines changed: 26 additions & 31 deletions
@@ -17,6 +17,7 @@
     TaxonRank,
     group_images_into_events,
 )
+from ami.ml.models.backend import Backend
 from ami.ml.tasks import create_detection_images
 from ami.tests.fixtures.storage import GeneratedTestFrame, create_storage_source, populate_bucket
 
@@ -33,62 +34,56 @@ def update_site_settings(**kwargs):
     return site
 
 
-def create_ml_pipeline(project):
-    from ami.ml.models import Algorithm, Pipeline
-
-    pipelines_to_add = [
+# @TODO: To test this: delete project in admin, then run migrate
+# (this will execute the signal in ami-platform/ami/tests/fixtures/signals.py)
+def create_ml_backends(project):
+    backends_to_add = [
         {
-            "name": "ML Dummy Backend",
-            "slug": "dummy",
-            "version": 1,
-            "algorithms": [
-                {"name": "Dummy Detector", "key": 1},
-                {"name": "Random Detector", "key": 2},
-                {"name": "Always Moth Classifier", "key": 3},
-            ],
-            "projects": {"name": project.name},
-            "endpoint_url": "http://ml_backend:2000/pipeline/process",
+            "projects": [{"name": project.name}],
+            "endpoint_url": "http://ml_backend:2000",
         },
     ]
 
-    for pipeline_data in pipelines_to_add:
-        pipeline, created = Pipeline.objects.get_or_create(
-            name=pipeline_data["name"],
-            slug=pipeline_data["slug"],
-            version=pipeline_data["version"],
-            endpoint_url=pipeline_data["endpoint_url"],
+    for backend_data in backends_to_add:
+        backend, created = Backend.objects.get_or_create(
+            endpoint_url=backend_data["endpoint_url"],
         )
 
         if created:
-            logger.info(f'Successfully created {pipeline_data["name"]}.')
+            logger.info(f'Successfully created backend with {backend_data["endpoint_url"]}.')
         else:
-            logger.info(f'Using existing pipeline {pipeline_data["name"]}.')
+            logger.info(f'Using existing backend with {backend_data["endpoint_url"]}.')
+
+        for project_data in backend_data["projects"]:
+            try:
+                project = Project.objects.get(name=project_data["name"])
+                backend.projects.add(project)
+            except Exception:
+                logger.error(f'Could not find project {project_data["name"]}.')
 
-        for algorithm_data in pipeline_data["algorithms"]:
-            algorithm, _ = Algorithm.objects.get_or_create(name=algorithm_data["name"], key=algorithm_data["key"])
-            pipeline.algorithms.add(algorithm)
+        backend.save()
 
-        pipeline.save()
+    backend.create_pipelines()
 
-    return pipeline
+    return backend
 
 
 def setup_test_project(reuse=True) -> tuple[Project, Deployment]:
+    short_id = "1ed10463"
     if reuse:
-        project, _ = Project.objects.get_or_create(name="Test Project")
+        project, _ = Project.objects.get_or_create(name=f"Test Project {short_id}")
         data_source = create_storage_source(project, "Test Data Source")
         deployment, _ = Deployment.objects.get_or_create(
             project=project, name="Test Deployment", defaults=dict(data_source=data_source)
         )
-        create_ml_pipeline(project)
+        create_ml_backends(project)
    else:
-        short_id = uuid.uuid4().hex[:8]
         project = Project.objects.create(name=f"Test Project {short_id}")
         data_source = create_storage_source(project, f"Test Data Source {short_id}")
         deployment = Deployment.objects.create(
             project=project, name=f"Test Deployment {short_id}", data_source=data_source
         )
-        create_ml_pipeline(project)
+        create_ml_backends(project)
     return project, deployment
 
 

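Taken together, the fixture now registers a `Backend` keyed on `endpoint_url` (so re-running it is idempotent), links it to the project, and lets the backend register its own pipelines via `create_pipelines()`. A minimal usage sketch, assuming a dev environment where the `ml_backend` service from this fixture is reachable (for example from `python manage.py shell`):

```python
from ami.tests.fixtures.main import create_ml_backends, setup_test_project

# Creates (or reuses) the test project and deployment and registers the ML backend for it.
project, deployment = setup_test_project()

# Idempotent: Backend.objects.get_or_create is keyed on endpoint_url, so a second call
# returns the existing backend instead of creating a duplicate.
backend = create_ml_backends(project)
print(backend.endpoint_url, backend.pipelines.count())
```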