Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 48 additions & 30 deletions backend/app/api/endpoints/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,75 +5,93 @@
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.openapi.models import Example

from app.api.dependencies import get_model_id, get_model_service
from app.schemas import Model
from app.api.dependencies import get_model_id, get_model_service, get_project_id
from app.schemas import Label, Model
from app.services import ModelService, ResourceInUseError, ResourceNotFoundError

router = APIRouter(prefix="/api/models", tags=["Models"])


UPDATE_MODEL_BODY_EXAMPLES = {
"rename_model": Example(
summary="Rename model",
description="Change the name of the model",
value={
"name": "New Model Name",
},
)
}
router = APIRouter(prefix="/api/projects/{project_id}/models", tags=["Models"])


@router.get(
"",
response_model=list[Model],
responses={
status.HTTP_200_OK: {"description": "List of available models", "model": list[Model]},
status.HTTP_200_OK: {"description": "List of available models"},
status.HTTP_400_BAD_REQUEST: {"description": "Invalid project ID"},
status.HTTP_404_NOT_FOUND: {"description": "Project not found"},
},
)
async def list_models(model_service: Annotated[ModelService, Depends(get_model_service)]) -> list[Model]:
"""Get information about available models"""
return model_service.list_models()
def list_models(
project_id: Annotated[UUID, Depends(get_project_id)],
model_service: Annotated[ModelService, Depends(get_model_service)],
) -> list[Model]:
"""Get all models in a project."""
try:
return model_service.list_models(project_id)
except ResourceNotFoundError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Project not found")


@router.get(
"/{model_id}",
response_model=Model,
responses={
status.HTTP_200_OK: {"description": "Model found", "model": Model},
status.HTTP_400_BAD_REQUEST: {"description": "Invalid model ID"},
status.HTTP_404_NOT_FOUND: {"description": "Model not found"},
status.HTTP_200_OK: {"description": "Model found"},
status.HTTP_400_BAD_REQUEST: {"description": "Invalid project or model ID"},
status.HTTP_404_NOT_FOUND: {"description": "Project or model not found"},
},
)
async def get_model(
def get_model(
project_id: Annotated[UUID, Depends(get_project_id)],
model_id: Annotated[UUID, Depends(get_model_id)],
model_service: Annotated[ModelService, Depends(get_model_service)],
) -> Model:
"""Get information about a specific model"""
"""Get a specific model by ID."""
try:
return model_service.get_model_by_id(model_id)
return model_service.get_model_by_id(project_id, model_id)
except ResourceNotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))


@router.get(
"/{model_id}/labels",
responses={
status.HTTP_200_OK: {"description": "Model labels found"},
status.HTTP_400_BAD_REQUEST: {"description": "Invalid project or model ID"},
status.HTTP_404_NOT_FOUND: {"description": "Project or model not found"},
},
)
def get_model_labels(
project_id: Annotated[UUID, Depends(get_project_id)],
model_id: Annotated[UUID, Depends(get_model_id)],
model_service: Annotated[ModelService, Depends(get_model_service)],
) -> list[Label]:
"""Get labels for a specific model."""
_ = project_id, model_id, model_service
raise NotImplementedError("Model labels endpoint is not implemented yet")


@router.delete(
"/{model_id}",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {
"description": "Model configuration successfully deleted",
},
status.HTTP_400_BAD_REQUEST: {"description": "Invalid model ID"},
status.HTTP_404_NOT_FOUND: {"description": "Model not found"},
status.HTTP_400_BAD_REQUEST: {"description": "Invalid project or model ID"},
status.HTTP_404_NOT_FOUND: {"description": "Project or model not found"},
status.HTTP_409_CONFLICT: {"description": "Model is used by at least one pipeline"},
},
)
async def delete_model(
def delete_model(
project_id: Annotated[UUID, Depends(get_project_id)],
model_id: Annotated[UUID, Depends(get_model_id)],
model_service: Annotated[ModelService, Depends(get_model_service)],
) -> None:
"""Delete a model"""
"""Delete a model from a project."""
try:
model_service.delete_model_by_id(model_id)
model_service.delete_model_by_id(project_id, model_id)
except ResourceNotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except ResourceInUseError as e:
Expand Down
1 change: 1 addition & 0 deletions backend/app/db/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class ProjectDB(Base):

pipeline = relationship("PipelineDB", back_populates="project", uselist=False)
labels = relationship("LabelDB", back_populates="project")
model_revisions = relationship("ModelRevisionDB", backref="project")


class PipelineDB(Base):
Expand Down
4 changes: 4 additions & 0 deletions backend/app/repositories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime
from typing import TypeVar

from sqlalchemy import exists
from sqlalchemy.orm import Session

from app.db.schema import Base
Expand All @@ -21,6 +22,9 @@ def __init__(self, db: Session, model: type[ModelType]) -> None:
def get_by_id(self, obj_id: str) -> ModelType | None:
return self.db.get(self.model, obj_id)

def exists(self, obj_id: str) -> bool:
return self.db.query(exists().where(self.model.id == obj_id)).scalar() # type: ignore[attr-defined]

def list_all(self) -> list[ModelType]:
return self.db.query(self.model).all()

Expand Down
2 changes: 1 addition & 1 deletion backend/app/repositories/label_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class LabelRepository(BaseRepository[LabelDB]):
"""Repository for label-related database operations."""

def __init__(self, db: Session, project_id: str):
def __init__(self, project_id: str, db: Session):
super().__init__(db, LabelDB)
self.project_id = project_id

Expand Down
2 changes: 1 addition & 1 deletion backend/app/services/label_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def update_labels_in_project(
"""
try:
with get_db_session() as db:
label_repo = LabelRepository(db, str(project_id))
label_repo = LabelRepository(project_id=str(project_id), db=db)
if labels_to_update:
label_repo.update_batch(_convert_labels_to_db(labels_to_update, project_id))
if label_ids_to_remove:
Expand Down
105 changes: 87 additions & 18 deletions backend/app/services/model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from uuid import UUID

from model_api.models import Model
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError

from app.db import get_db_session
from app.repositories import ModelRevisionRepository
from app.repositories import ModelRevisionRepository, ProjectRepository
from app.schemas.model import Model as ModelSchema
from app.schemas.model_activation import ModelActivationState
from app.services.base import GenericPersistenceService, ResourceNotFoundError, ResourceType, ServiceConfig
from app.services.base import ResourceInUseError, ResourceNotFoundError, ResourceType
from app.services.mappers.model_revision_mapper import ModelRevisionMapper
from app.services.parent_process_guard import parent_process_only

Expand All @@ -37,11 +37,9 @@ class ModelService:
def __init__(self, data_dir: Path, mp_model_reload_event: EventClass | None = None) -> None:
self.models_dir = data_dir / "models"
self._mp_model_reload_event = mp_model_reload_event
self._persistence: GenericPersistenceService[Model, ModelRevisionRepository] = GenericPersistenceService(
ServiceConfig(ModelRevisionRepository, ModelRevisionMapper, ResourceType.MODEL)
)
self._model_activation_state: ModelActivationState = self._load_state()
self._loaded_model: LoadedModel | None = None
self._mapper = ModelRevisionMapper()

@staticmethod
def _load_state() -> ModelActivationState:
Expand Down Expand Up @@ -92,19 +90,90 @@ def get_loaded_inference_model(self, force_reload: bool = False) -> LoadedModel
)
return self._loaded_model

def get_model_by_id(self, model_id: UUID, db: Session | None = None) -> ModelSchema:
"""Get a model by its ID"""
model = self._persistence.get_by_id(model_id, db)
if not model:
raise ResourceNotFoundError(ResourceType.MODEL, str(model_id))
return model
def get_model_by_id(self, project_id: UUID, model_id: UUID) -> ModelSchema:
"""
Get a model by its ID within a specific project.

Retrieves a model revision from the specified project by matching the model ID.
The method first validates that the project exists, then searches through the
project's model revisions to find the one with the matching ID.

Args:
project_id (UUID): The unique identifier of the project containing the model.
model_id (UUID): The unique identifier of the model to retrieve.

Returns:
ModelSchema: The model schema object containing the model's information.

Raises:
ResourceNotFoundError: If the project with the given project_id does not exist,
or if no model with the given model_id is found within the project.
"""
with get_db_session() as db:
project_repo = ProjectRepository(db)
project = project_repo.get_by_id(str(project_id))
if not project:
raise ResourceNotFoundError(ResourceType.PROJECT, str(project_id))
model = next((self._mapper.to_schema(m) for m in project.model_revisions if m.id == str(model_id)), None)
if not model:
raise ResourceNotFoundError(ResourceType.MODEL, str(model_id))
return model

@parent_process_only
def delete_model_by_id(self, model_id: UUID) -> None:
"""Delete a model by its ID"""
def delete_model_by_id(self, project_id: UUID, model_id: UUID) -> None:
"""
Delete a model by its ID from a specific project.

Permanently removes a model revision from the specified project. The method
first validates that the project exists, then attempts to delete the model
from the database. This operation is restricted to the parent process only.

Args:
project_id (UUID): The unique identifier of the project containing the model.
model_id (UUID): The unique identifier of the model to delete.

Returns:
None

Raises:
ResourceNotFoundError: If the project with the given project_id does not exist,
or if no model with the given model_id is found.
ResourceInUseError: If the model cannot be deleted due to integrity constraints
(e.g., the model is referenced by other entities).
"""
with get_db_session() as db:
self._persistence.delete_by_id(model_id, db)
project_repo = ProjectRepository(db)
if not project_repo.exists(str(project_id)):
raise ResourceNotFoundError(ResourceType.PROJECT, str(project_id))
model_rev_repo = ModelRevisionRepository(db)
try:
deleted = model_rev_repo.delete(str(model_id))
if not deleted:
raise ResourceNotFoundError(ResourceType.MODEL, str(model_id))
except IntegrityError:
raise ResourceInUseError(ResourceType.MODEL, str(model_id))

def list_models(self, project_id: UUID) -> list[ModelSchema]:
"""
Get information about all available model revisions in a project.

Retrieves a list of all model revisions that belong to the specified project.
Each model revision is converted to a schema object containing the model's
metadata and configuration information.

def list_models(self) -> list[ModelSchema]:
"""Get information about available models"""
return self._persistence.list_all()
Args:
project_id (UUID): The unique identifier of the project whose models to list.

Returns:
list[ModelSchema]: A list of model schema objects representing all model
revisions in the project. Returns an empty list if the project has no models.

Raises:
ResourceNotFoundError: If the project with the given project_id does not exist.
"""
with get_db_session() as db:
project_repo = ProjectRepository(db)
project = project_repo.get_by_id(str(project_id))
if not project:
raise ResourceNotFoundError(ResourceType.PROJECT, str(project_id))
return [self._mapper.to_schema(model_rev_db) for model_rev_db in project.model_revisions]
50 changes: 30 additions & 20 deletions backend/tests/integration/services/test_model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,47 +61,57 @@ class TestModelServiceIntegration:

def test_list_models(self, fxt_db_projects, fxt_db_models, fxt_model_service, db_session):
"""Test retrieving all models."""
create_model_db(fxt_db_projects[0], fxt_db_models, db_session)
db_project = fxt_db_projects[0]
create_model_db(db_project, fxt_db_models, db_session)

models = fxt_model_service.list_models()
models = fxt_model_service.list_models(UUID(db_project.id))

assert len(models) == len(fxt_db_models)
for i, model in enumerate(models):
assert_model(model, fxt_db_models[i])

@pytest.mark.parametrize("model_operation", ["list_models", "get_model_by_id", "delete_model_by_id"])
def test_model_with_non_existent_project(self, model_operation, fxt_model_service):
"""Test deleting a model from non-existent project raises error."""
project_id, model_id = uuid4(), uuid4()
with pytest.raises(ResourceNotFoundError) as excinfo:
if model_operation == "list_models":
getattr(fxt_model_service, model_operation)(project_id)
else:
getattr(fxt_model_service, model_operation)(project_id, model_id)

assert excinfo.value.resource_type == ResourceType.PROJECT
assert excinfo.value.resource_id == str(project_id)

def test_get_model(self, fxt_db_projects, fxt_db_models, fxt_model_service, db_session):
"""Test retrieving a model by ID."""
db_model = fxt_db_models[0]
create_model_db(fxt_db_projects[0], [db_model], db_session)
db_project, db_model = fxt_db_projects[0], fxt_db_models[0]
create_model_db(db_project, [db_model], db_session)

model = fxt_model_service.get_model_by_id(UUID(db_model.id))
model = fxt_model_service.get_model_by_id(UUID(db_project.id), UUID(db_model.id))

assert model is not None
assert_model(model, db_model)

def test_get_non_existent_model(self, fxt_model_service):
@pytest.mark.parametrize("model_operation", ["get_model_by_id", "delete_model_by_id"])
def test_non_existent_model(self, model_operation, fxt_db_projects, fxt_model_service, db_session):
"""Test retrieving a non-existent model raises error."""
model_id = uuid4()
db_project = fxt_db_projects[0]
db_session.add(db_project)
db_session.flush()

project_id, model_id = UUID(db_project.id), uuid4()
with pytest.raises(ResourceNotFoundError) as excinfo:
fxt_model_service.get_model_by_id(model_id)
getattr(fxt_model_service, model_operation)(project_id, model_id)

assert excinfo.value.resource_type == ResourceType.MODEL
assert excinfo.value.resource_id == str(model_id)

def test_delete_model(self, fxt_db_projects, fxt_db_models, fxt_model_service, db_session):
"""Test deleting a model by ID."""
db_model = fxt_db_models[0]
create_model_db(fxt_db_projects[0], [db_model], db_session)
db_project, db_model = fxt_db_projects[0], fxt_db_models[0]
create_model_db(db_project, [db_model], db_session)

fxt_model_service.delete_model_by_id(UUID(db_model.id))
fxt_model_service.delete_model_by_id(UUID(db_project.id), UUID(db_model.id))

assert db_session.query(ModelRevisionDB).count() == 0

def test_delete_non_existent_model(self, fxt_model_service):
"""Test deleting a non-existent model raises error."""
model_id = uuid4()
with pytest.raises(ResourceNotFoundError) as excinfo:
fxt_model_service.delete_model_by_id(model_id)

assert excinfo.value.resource_type == ResourceType.MODEL
assert excinfo.value.resource_id == str(model_id)
Loading
Loading