Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ami/ml/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class PipelineNotConfigured(ValueError):
pass
79 changes: 58 additions & 21 deletions ami/ml/models/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand All @@ -15,6 +16,18 @@
from ami.base.models import BaseModel


@typing.final
class AlgorithmCategoryMapManager(models.Manager["AlgorithmCategoryMap"]):
def create(self, *args, **kwargs):
"""
Create a new AlgorithmCategoryMap instance and generate its labels_hash.
"""
instance = super().create(*args, **kwargs)
instance.labels_hash = instance.make_labels_hash(instance.labels)
instance.save()
return instance


@typing.final
class AlgorithmCategoryMap(BaseModel):
"""
Expand Down Expand Up @@ -44,6 +57,8 @@ class AlgorithmCategoryMap(BaseModel):

algorithms: models.QuerySet[Algorithm]

objects = AlgorithmCategoryMapManager()

def __str__(self):
return f"#{self.pk} with {len(self.labels)} classes ({self.version or 'unknown version'})"

Expand Down Expand Up @@ -110,6 +125,31 @@ def with_category_count(self):
return self.annotate(category_count=ArrayLength("category_map__labels"))


# Task types enum for better type checking
class AlgorithmTaskType(str, enum.Enum):
DETECTION = "detection"
LOCALIZATION = "localization"
SEGMENTATION = "segmentation"
CLASSIFICATION = "classification"
EMBEDDING = "embedding"
TRACKING = "tracking"
TAGGING = "tagging"
REGRESSION = "regression"
CAPTIONING = "captioning"
GENERATION = "generation"
TRANSLATION = "translation"
SUMMARIZATION = "summarization"
QUESTION_ANSWERING = "question_answering"
DEPTH_ESTIMATION = "depth_estimation"
POSE_ESTIMATION = "pose_estimation"
SIZE_ESTIMATION = "size_estimation"
OTHER = "other"
UNKNOWN = "unknown"

def as_choice(self):
return (self.value, self.name.replace("_", " ").title())


@typing.final
class Algorithm(BaseModel):
"""A machine learning algorithm"""
Expand All @@ -120,28 +160,8 @@ class Algorithm(BaseModel):
max_length=255,
default="unknown",
null=True,
choices=[
("detection", "Detection"),
("localization", "Localization"),
("segmentation", "Segmentation"),
("classification", "Classification"),
("embedding", "Embedding"),
("tracking", "Tracking"),
("tagging", "Tagging"),
("regression", "Regression"),
("captioning", "Captioning"),
("generation", "Generation"),
("translation", "Translation"),
("summarization", "Summarization"),
("question_answering", "Question Answering"),
("depth_estimation", "Depth Estimation"),
("pose_estimation", "Pose Estimation"),
("size_estimation", "Size Estimation"),
("other", "Other"),
("unknown", "Unknown"),
],
choices=[task_type.as_choice() for task_type in AlgorithmTaskType],
)
detection_algorithm_task_types = ["detection", "localization", "segmentation"]
description = models.TextField(blank=True)
version = models.IntegerField(
default=1,
Expand Down Expand Up @@ -172,6 +192,16 @@ class Algorithm(BaseModel):

objects = AlgorithmQuerySet.as_manager()

detection_task_types = [
AlgorithmTaskType.DETECTION,
AlgorithmTaskType.LOCALIZATION,
AlgorithmTaskType.SEGMENTATION,
]
classification_task_types = [
AlgorithmTaskType.CLASSIFICATION,
AlgorithmTaskType.TAGGING,
]

def __str__(self):
return f'#{self.pk} "{self.name}" ({self.key}) v{self.version}'

Expand All @@ -197,3 +227,10 @@ def category_count(self) -> int | None:
but is defined here for the serializer to work.
"""
return None

def has_valid_category_map(self):
return (
(self.category_map is not None)
and (self.category_map.data is not None)
and (len(self.category_map.data) > 0)
)
Loading