Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ami/jobs/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from django.utils.translation import gettext_lazy as _


class UsersConfig(AppConfig):
class JobsConfig(AppConfig):
name = "ami.jobs"
verbose_name = _("Jobs")
11 changes: 5 additions & 6 deletions ami/jobs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,10 +399,9 @@ def run(cls, job: "Job"):
total_detections = 0
total_classifications = 0

# Set to low size because our response JSON just got enormous
# @TODO make this configurable
CHUNK_SIZE = 1
chunks = [images[i : i + CHUNK_SIZE] for i in range(0, image_count, CHUNK_SIZE)] # noqa
config = job.pipeline.get_config(project_id=job.project.pk)
chunk_size = config.get("request_source_image_batch_size", 1)
chunks = [images[i : i + chunk_size] for i in range(0, image_count, chunk_size)] # noqa
request_failed_images = []

for i, chunk in enumerate(chunks):
Expand Down Expand Up @@ -434,9 +433,9 @@ def run(cls, job: "Job"):
"process",
status=JobState.STARTED,
progress=(i + 1) / len(chunks),
processed=min((i + 1) * CHUNK_SIZE, image_count),
processed=min((i + 1) * chunk_size, image_count),
failed=len(request_failed_images),
remaining=max(image_count - ((i + 1) * CHUNK_SIZE), 0),
remaining=max(image_count - ((i + 1) * chunk_size), 0),
)

# count the completed, successful, and failed save_tasks:
Expand Down
2 changes: 1 addition & 1 deletion ami/labelstudio/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from django.utils.translation import gettext_lazy as _


class UsersConfig(AppConfig):
class LabelStudioConfig(AppConfig):
name = "ami.labelstudio"
verbose_name = _("Label Studio Integration")
2 changes: 1 addition & 1 deletion ami/ml/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from django.utils.translation import gettext_lazy as _


class UsersConfig(AppConfig):
class MLConfig(AppConfig):
name = "ami.ml"
verbose_name = _("Machine Learning")
24 changes: 24 additions & 0 deletions ami/ml/migrations/0021_pipeline_default_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Generated by Django 4.2.10 on 2025-03-19 16:27

import ami.ml.schemas
from django.db import migrations
import django_pydantic_field.fields


class Migration(migrations.Migration):
    # Adds Pipeline.default_config: a schema-validated JSON field holding the
    # default request parameters for the pipeline (overridable per project via
    # ProjectPipelineConfig).
    dependencies = [
        ("ml", "0020_projectpipelineconfig_alter_pipeline_projects"),
    ]

    operations = [
        migrations.AddField(
            model_name="pipeline",
            name="default_config",
            field=django_pydantic_field.fields.PydanticSchemaField(
                config=None,
                # Default is an empty mapping, so existing rows get {}.
                default=dict,
                help_text="The default configuration for the pipeline. Used by both the job sending images to the pipeline and the processing service.",
                schema=ami.ml.schemas.PipelineRequestConfigParameters,
            ),
        ),
    ]
71 changes: 41 additions & 30 deletions ami/ml/models/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from ami.ml.models import ProcessingService # , ProjectPipelineConfig
from ami.ml.models import ProcessingService, ProjectPipelineConfig
from ami.jobs.models import Job

import collections
import dataclasses
Expand Down Expand Up @@ -40,6 +41,7 @@
ClassificationResponse,
DetectionResponse,
PipelineRequest,
PipelineRequestConfigParameters,
PipelineResultsResponse,
SourceImageRequest,
SourceImageResponse,
Expand Down Expand Up @@ -98,7 +100,7 @@ def filter_processed_images(
)
# log all algorithms that are in the pipeline but not in the detection
missing_algos = pipeline_algorithm_ids - detection_algorithm_ids
task_logger.info(f"Image #{image.pk} needs classification by pipeline's algorithms: {missing_algos}")
task_logger.debug(f"Image #{image.pk} needs classification by pipeline's algorithms: {missing_algos}")
yield image
else:
# If all detections have been classified by the pipeline, skip the image
Expand Down Expand Up @@ -162,9 +164,6 @@ def process_images(
) -> PipelineResultsResponse:
"""
Process images using ML pipeline API.

@TODO find a home for this function.
@TODO break into task chunks.
"""
job = None
task_logger = logger
Expand Down Expand Up @@ -201,29 +200,11 @@ def process_images(
if url
]

if project_id:
try:
config = pipeline.project_pipeline_configs.get(project_id=project_id).config
task_logger.info(
f"Sending pipeline request using {config} from the project-pipeline config "
f"for Pipeline {pipeline} and Project id {project_id}."
)
except pipeline.project_pipeline_configs.model.DoesNotExist as e:
task_logger.error(
f"Error getting the project-pipeline config for Pipeline {pipeline} "
f"and Project id {project_id}: {e}"
)
config = {}
task_logger.info(
"Using empty config when sending pipeline request since no project-pipeline config "
f"was found for Pipeline {pipeline} and Project id {project_id}"
)
else:
config = {}
task_logger.info(
"Using empty config when sending pipeline request "
f"since no project id was provided for Pipeline {pipeline}"
)
if not project_id:
task_logger.warning(f"Pipeline {pipeline} is not associated with a project")

config = pipeline.get_config(project_id=project_id)
task_logger.info(f"Using pipeline config: {config}")

request_data = PipelineRequest(
pipeline=pipeline.slug,
Expand Down Expand Up @@ -914,7 +895,7 @@ class Pipeline(BaseModel):
description = models.TextField(blank=True)
version = models.IntegerField(default=1)
version_name = models.CharField(max_length=255, blank=True)
# @TODO the algorithms list be retrieved by querying the pipeline endpoint
# @TODO the algorithms attribute is not currently used. Review for removal.
algorithms = models.ManyToManyField("ml.Algorithm", related_name="pipelines")
stages: list[PipelineStage] = SchemaField(
default=default_stages,
Expand All @@ -926,8 +907,18 @@ class Pipeline(BaseModel):
projects = models.ManyToManyField(
"main.Project", related_name="pipelines", blank=True, through="ml.ProjectPipelineConfig"
)
default_config: PipelineRequestConfigParameters = SchemaField(
schema=PipelineRequestConfigParameters,
default=dict,
help_text=(
"The default configuration for the pipeline. "
"Used by both the job sending images to the pipeline "
"and the processing service."
),
)
processing_services: models.QuerySet[ProcessingService]
# project_pipeline_configs: models.QuerySet[ProjectPipelineConfig]
project_pipeline_configs: models.QuerySet[ProjectPipelineConfig]
jobs: models.QuerySet[Job]

class Meta:
ordering = ["name", "version"]
Expand All @@ -939,6 +930,26 @@ class Meta:
def __str__(self):
return f'#{self.pk} "{self.name}" ({self.slug}) v{self.version}'

def get_config(self, project_id: int | None = None) -> PipelineRequestConfigParameters:
    """
    Build the effective configuration for a pipeline request.

    Starts from a copy of ``self.default_config``. If a project ID is provided
    and a ProjectPipelineConfig exists for that project, its non-empty ``config``
    entries override the defaults.

    :param project_id: Optional project whose pipeline config overrides the defaults.
    :return: The merged config. Never the stored ``default_config`` object itself.
    """
    # Copy the defaults so the update() below can never mutate the
    # pipeline's stored default_config in place (it is a dict subclass,
    # so plain assignment would alias it).
    config = PipelineRequestConfigParameters(self.default_config)
    if project_id:
        try:
            project_pipeline_config = self.project_pipeline_configs.get(project_id=project_id)
            if project_pipeline_config.config:
                config.update(project_pipeline_config.config)
            logger.debug(
                f"Using ProjectPipelineConfig for Pipeline {self} and Project #{project_id}: config: {config}"
            )
        except self.project_pipeline_configs.model.DoesNotExist as e:
            logger.warning(f"No project-pipeline config for Pipeline {self} and Project #{project_id}: {e}")
    return config

def collect_images(
self,
collection: SourceImageCollection | None = None,
Expand Down
8 changes: 4 additions & 4 deletions ami/ml/models/processing_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def create_pipelines(self):
algorithms_created=algorithms_created,
)

def get_status(self):
def get_status(self, timeout=6):
"""
Check the status of the processing service.
This is a simple health check that pings the /readyz endpoint of the service.
Expand All @@ -116,7 +116,7 @@ def get_status(self):
resp = None

try:
resp = requests.get(ready_check_url)
resp = requests.get(ready_check_url, timeout=timeout)
resp.raise_for_status()
self.last_checked_live = True
latency = time.time() - start_time
Expand Down Expand Up @@ -158,13 +158,13 @@ def get_status(self):

return response

def get_pipeline_configs(self):
def get_pipeline_configs(self, timeout=6):
"""
Get the pipeline configurations from the processing service.
This can be a long response as it includes the full category map for each algorithm.
"""
info_url = urljoin(self.endpoint_url, "info")
resp = requests.get(info_url)
resp = requests.get(info_url, timeout=timeout)
resp.raise_for_status()
info_data = ProcessingServiceInfoResponse.parse_obj(resp.json())
return info_data.pipelines
21 changes: 20 additions & 1 deletion ami/ml/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,29 @@ class Config:
]


class PipelineRequestConfigParameters(dict):
    """Free-form key/value parameters configuring a pipeline request.

    Behaves exactly like a plain ``dict``; any serializable key-value pairs
    are accepted, e.g. ``{"force_reprocess": True, "auth_token": "abc123"}``.

    The set of supported parameters is defined by each pipeline on the
    processing-service side and should be advertised in the Pipeline's
    info response.

    Keys prefixed with ``"request_"`` (e.g.
    ``{"request_source_image_batch_size": 8}``) are consumed by Antenna
    before the request is sent to the Processing Service; the service's
    schema must either ignore them, or they must be stripped from the
    payload beforehand.
    """


class PipelineRequest(pydantic.BaseModel):
    """Request payload sent to a processing service to run a pipeline on images."""

    # Slug identifying which pipeline the service should run.
    pipeline: str
    source_images: list[SourceImageRequest]
    # Optional request parameters; see PipelineRequestConfigParameters for
    # conventions (including "request_"-prefixed keys handled before sending).
    config: PipelineRequestConfigParameters | dict | None = None


class PipelineResultsResponse(pydantic.BaseModel):
Expand Down
24 changes: 24 additions & 0 deletions ami/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,30 @@ def test_yes_reprocess_if_new_terminal_algorithm_same_intermediate(self):
remaining_images_to_process = len(images_again)
self.assertEqual(remaining_images_to_process, len(images), "Images not re-processed with new pipeline")

def test_project_pipeline_config(self):
    """
    Verify Pipeline.get_config(): without a project it returns the pipeline's
    default_config, and a ProjectPipelineConfig's parameters override those
    defaults when the project is supplied.
    """
    from ami.ml.models import ProjectPipelineConfig
    from ami.ml.schemas import PipelineRequestConfigParameters

    # Give the pipeline a default value and the project an overriding one
    # for the same parameter key.
    self.pipeline.default_config = PipelineRequestConfigParameters({"test_param": "test_value"})
    self.pipeline.save()
    self.project_pipeline_config = ProjectPipelineConfig.objects.create(
        project=self.project,
        pipeline=self.pipeline,
        config={"test_param": "project_value"},
    )
    self.project_pipeline_config.save()

    # No project given: the pipeline default wins.
    self.assertEqual(self.pipeline.get_config()["test_param"], "test_value")
    # Project given: the project-level config overrides the default.
    self.assertEqual(self.pipeline.get_config(self.project.pk)["test_param"], "project_value")


class TestAlgorithmCategoryMaps(TestCase):
def setUp(self):
Expand Down
23 changes: 22 additions & 1 deletion processing_services/example/api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,30 @@ class Config:
PipelineChoice = typing.Literal["random", "constant"]


class PipelineRequestConfigParameters(pydantic.BaseModel):
    """Parameters used to configure a pipeline request.

    Accepts any serializable key-value pair.
    Example: {"force_reprocess": True, "auth_token": "abc123"}

    Supported parameters are defined by the pipeline in the processing service
    and should be published in the Pipeline's info response.

    NOTE(review): "accepts any key-value pair" assumes extra fields are
    permitted, but a plain pydantic BaseModel ignores/rejects unknown keys
    unless ``extra`` is configured — confirm the intended behavior.
    """

    # When True, the image is reprocessed even if prior results exist.
    force_reprocess: bool = pydantic.Field(
        default=False,
        description="Force reprocessing of the image, even if it has already been processed.",
    )
    # Optional authentication token forwarded with the pipeline request.
    auth_token: str | None = pydantic.Field(
        default=None,
        description="An optional authentication token to use for the pipeline.",
    )


class PipelineRequest(pydantic.BaseModel):
pipeline: PipelineChoice
source_images: list[SourceImageRequest]
config: dict
config: PipelineRequestConfigParameters | dict | None = None

# Example for API docs:
class Config:
Expand All @@ -203,6 +223,7 @@ class Config:
"url": "https://archive.org/download/mma_various_moths_and_butterflies_54143/54143.jpg",
}
],
"config": {"force_reprocess": True, "auth_token": "abc123"},
}
}

Expand Down