Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
e5287b1
fix: return no urls instead of broken urls
mihow Sep 5, 2024
c048cda
feat: add tests for s3 storage source
mihow Sep 5, 2024
4ef1cf9
fix: require source images to have a data store defined
mihow Sep 5, 2024
5ea2d66
feat: generate test source images
mihow Sep 6, 2024
5440178
feat: create occurrences & detections from test images
mihow Sep 6, 2024
a31a5e2
fix: use django storages for detection crops
mihow Sep 6, 2024
a111435
feat: methods for creating crops from bboxes from the platform
mihow Sep 6, 2024
a8951a3
fix: cachalot pinned version
mihow Sep 6, 2024
5483c07
fix: optionally raise errors when generating image public urls
mihow Sep 6, 2024
9ca2a27
fix: bounding boxes that overlap outside of source image frame
mihow Sep 6, 2024
11f225b
fix: continue passed failed images
mihow Sep 6, 2024
1d4b645
fix: get tests to pass
mihow Sep 6, 2024
51a9ebf
Merge branch 'main' into fix/missing-crops
mihow Sep 9, 2024
988f998
Merge branch 'main' into fix/missing-crops
mihow Sep 10, 2024
47154b7
chore: rename detection cropping methods
mihow Sep 10, 2024
56cd8d0
feat: create cropped detection images after processing
mihow Sep 11, 2024
97561e6
chore: rename management command
mihow Sep 11, 2024
ae1c699
fix: filter duplicate classifications on the platform side
mihow Sep 11, 2024
c728e3e
docs: describe what the detection path field is intended for
mihow Sep 11, 2024
25b6cb4
fix: correctly save cropped url path from ml backend, if one is returned
mihow Sep 11, 2024
a9063aa
fix: allow main result-saving task to complete
mihow Sep 11, 2024
2e6423d
Merge branch 'main' into fix/missing-crops
mihow Sep 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .envs/.ci/.django
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ DJANGO_CSRF_TRUSTED_ORIGINS=http://localhost:3000,
MINIO_ENDPOINT=http://minio:9000
MINIO_ROOT_USER=amistorage
MINIO_ROOT_PASSWORD=amistorage
MINIO_DEFAULT_BUCKET=ami
MINIO_DEFAULT_BUCKET=ami-ci
MINIO_STORAGE_USE_HTTPS=False
MINIO_TEST_BUCKET=ami-test
MINIO_TEST_BUCKET=ami-test-ci
MINIO_BROWSER_REDIRECT_URL=http://minio:9001
57 changes: 42 additions & 15 deletions ami/main/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from django.apps import apps
from django.conf import settings
from django.core.exceptions import ValidationError
from django.core.files.storage import default_storage
from django.db import IntegrityError, models
from django.db.models import Q
from django.db.models.fields.files import ImageFieldFile
Expand Down Expand Up @@ -62,11 +63,6 @@ class TaxonRank(OrderedEnum):
)


# @TODO move to settings & make configurable
_SOURCE_IMAGES_URL_BASE = "https://static.dev.insectai.org/ami-trapdata/vermont/snapshots/"
_CROPS_URL_BASE = "https://static.dev.insectai.org/ami-trapdata/crops"


def get_media_url(path: str) -> str:
"""
If path is a full URL, return it as-is.
Expand All @@ -77,7 +73,8 @@ def get_media_url(path: str) -> str:
if path.startswith("http"):
url = path
else:
url = urllib.parse.urljoin(_CROPS_URL_BASE, path.lstrip("/"))
# @TODO add a file field to the Detection model and use that to get the URL
url = default_storage.url(path.lstrip("/"))
return url


Expand Down Expand Up @@ -1172,7 +1169,7 @@ class SourceImage(BaseModel):
def __str__(self) -> str:
return f"{self.__class__.__name__} #{self.pk} {self.path}"

def public_url(self) -> str:
def public_url(self, raise_errors=False) -> str | None:
"""
Return the public URL for this image.

Expand All @@ -1192,26 +1189,45 @@ def public_url(self) -> str:
and data_source.access_key
and data_source.secret_key
):
return ami.utils.s3.get_presigned_url(data_source.config, key=self.path)
url = ami.utils.s3.get_presigned_url(data_source.config, key=self.path)
elif self.public_base_url:
url = urllib.parse.urljoin(self.public_base_url, self.path.lstrip("/"))
else:
return urllib.parse.urljoin(self.public_base_url or "/", self.path.lstrip("/"))
msg = f"Public URL for {self} is not available. Public base URL: '{self.public_base_url}'"
if raise_errors:
raise ValueError(msg)
else:
logger.error(msg)
return None
# Ensure url has a scheme
if not urllib.parse.urlparse(url).netloc:
msg = f"Public URL for {self} is invalid: {url}. Public base URL: '{self.public_base_url}'"
if raise_errors:
raise ValueError(msg)
else:
logger.error(msg)
return None
else:
return url

# backwards compatibility
url = public_url

def get_detections_count(self) -> int:
return self.detections.distinct().count()

def get_base_url(self) -> str:
def get_base_url(self) -> str | None:
"""
Determine the public URL from the deployment's data source.

If there is no data source, return a relative URL.
If there is no data source, return None

If the public_base_url is None, a presigned URL will be generated for each request.
"""
if self.deployment and self.deployment.data_source and self.deployment.data_source.public_base_url:
return self.deployment.data_source.public_base_url
else:
return "/"
return None

def extract_timestamp(self) -> datetime.datetime | None:
"""
Expand Down Expand Up @@ -1695,7 +1711,16 @@ class Detection(BaseModel):
# upload_to="detections",
# ),
# )
path = models.CharField(max_length=255, blank=True, null=True)
path = models.CharField(
max_length=255,
blank=True,
null=True,
help_text=(
"Either a full URL to a cropped detection image or a relative path to a file in the default "
"project storage. @TODO ensure all detection crops are hosted in the project storage, "
"not the default media storage. Migrate external URLs."
),
)

occurrence = models.ForeignKey(
"Occurrence",
Expand Down Expand Up @@ -1912,8 +1937,10 @@ def duration_label(self) -> str | None:
return ami.utils.dates.format_timedelta(duration)

def detection_images(self, limit=None):
for url in Detection.objects.filter(occurrence=self).exclude(path=None).values_list("path", flat=True)[:limit]:
yield urllib.parse.urljoin(_CROPS_URL_BASE, url)
for path in (
Detection.objects.filter(occurrence=self).exclude(path=None).values_list("path", flat=True)[:limit]
):
yield get_media_url(path)

@functools.cached_property
def best_detection(self):
Expand Down
65 changes: 65 additions & 0 deletions ami/ml/management/commands/create_missing_detection_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from django.core.management.base import BaseCommand
from tqdm import tqdm

from ami.main.models import SourceImage
from ami.ml.media import create_detection_images_from_source_image, get_source_images_with_missing_detection_images


class Command(BaseCommand):
help = "Create crops for detections with missing paths"

def add_arguments(self, parser):
parser.add_argument("--project", type=int, help="Project ID to process")
parser.add_argument("--batch-size", type=int, default=100, help="Batch size for processing")

def handle(self, *args, **options):
project_id = options["project"]
batch_size = options["batch_size"]

queryset = SourceImage.objects.all()
if project_id:
queryset = queryset.filter(project_id=project_id)

total_images = get_source_images_with_missing_detection_images(queryset).count()
self.stdout.write(f"Found {total_images}+ source images with missing detection crops")

processed_images = 0
processed_detections = 0
errors: list[tuple[SourceImage, str]] = []

with tqdm(total=total_images, desc="Processing images", unit="img") as pbar:
while True:
# Exclude images that have known errors
queryset = queryset.exclude(id__in=[source_image.pk for source_image, _ in errors])
batch = get_source_images_with_missing_detection_images(queryset, batch_size)
if not batch:
break

for source_image in batch:
try:
processed_paths = create_detection_images_from_source_image(source_image)
processed_detections += len(processed_paths)
processed_images += 1
except Exception as e:
error_message = (
f"Error processing image {source_image} from project '{source_image.project}': {str(e)}"
)
self.stderr.write(error_message)
errors.append((source_image, error_message))
finally:
pbar.update(1)

self.stdout.write(
f"Processed {processed_images}/{total_images} images, {processed_detections} detections"
)

self.stdout.write(
self.style.SUCCESS(
f"Successfully processed {processed_images} images and {processed_detections} detections"
)
)

if errors:
self.stdout.write(self.style.WARNING(f"Encountered {len(errors)} errors:"))
for source_image, error_message in errors:
self.stdout.write(f" - {error_message}")
94 changes: 94 additions & 0 deletions ami/ml/media.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import io
import logging
import os

import numpy as np
import requests
from django.core.files.base import ContentFile
from django.core.files.storage import default_storage
from django.db.models import Exists, OuterRef, QuerySet
from PIL import Image

from ami.main.models import Detection, SourceImage

logger = logging.getLogger(__name__)


def get_source_images_with_missing_detection_images(
queryset: QuerySet[SourceImage] | None = None, batch_size: int = 100
) -> QuerySet[SourceImage]:
if queryset is None:
queryset = SourceImage.objects.all()

return queryset.filter(
Exists(Detection.objects.filter(source_image=OuterRef("pk"), path__isnull=True))
).prefetch_related("detections", "deployment__project")[:batch_size]


def fetch_image_content(url: str) -> bytes:
response = requests.get(url)
response.raise_for_status()
return response.content


def load_source_image(source_image: SourceImage) -> np.ndarray:
url = source_image.public_url(raise_errors=True)
assert url
image_content = fetch_image_content(url)
image = Image.open(io.BytesIO(image_content))
return np.array(image)


def crop_detection(image: np.ndarray, bbox: tuple[int, int, int, int]) -> Image.Image:
x1, y1, x2, y2 = bbox
# Check the bounding box is within the image and has a non-zero area
if x1 < 0 or y1 < 0 or x2 > image.shape[1] or y2 > image.shape[0]:
logger.warning(
f"Bounding box is outside the image. Image shape: {image.shape} Bounding box: {bbox}. "
"Clamping to image bounds."
)
# Set max and min values for x and y
x1 = max(0, x1)
y1 = max(0, y1)
x2 = min(image.shape[1], x2)
y2 = min(image.shape[0], y2)
if x1 >= x2 or y1 >= y2:
raise ValueError(f"Bounding box has zero area. Bounding box: {bbox} Width: {x2 - x1} Height: {y2 - y1}")
cropped_image = image[int(y1) : int(y2), int(x1) : int(x2)] # noqa: E203
img = Image.fromarray(cropped_image)
if not img.getbbox():
raise ValueError("Cropped image is empty")
return img


def save_crop(cropped_image: Image.Image, detection: Detection, source_image: SourceImage) -> str:
source_basename = os.path.splitext(os.path.basename(source_image.path))[0]
image_name = f"{source_basename}_detection_{detection.pk}.jpg"
iso_day = detection.timestamp.date().isoformat() if detection.timestamp else "unknown_date"
assert source_image.project, "Source image must belong to a project"
image_path = f"detections/{source_image.project.pk}/{iso_day}/{image_name}"

img_byte_arr = io.BytesIO()
cropped_image.save(img_byte_arr, format="JPEG")
img_byte_arr = img_byte_arr.getvalue()

return default_storage.save(image_path, ContentFile(img_byte_arr))


def create_detection_images_from_source_image(source_image: SourceImage) -> list[str]:
# Check if the source image has detections without images before loading the image
if not source_image.detections.filter(path__isnull=True).exists():
return []

image_np = load_source_image(source_image)
processed_paths = []

for detection in source_image.detections.filter(path__isnull=True):
if detection.bbox:
cropped_image = crop_detection(image_np, detection.bbox)
path = save_crop(cropped_image, detection, source_image)
detection.path = path
detection.save()
processed_paths.append(path)

return processed_paths
46 changes: 32 additions & 14 deletions ami/ml/models/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
TaxonRank,
update_calculated_fields_for_events,
)
from ami.ml.tasks import create_detection_images

from ..schemas import PipelineRequest, PipelineResponse, SourceImageRequest
from .algorithm import Algorithm
Expand Down Expand Up @@ -135,6 +136,7 @@ def process_images(
"""
job = None
images = list(images)
urls = [source_image.public_url() for source_image in images if source_image.public_url()]

if job_id:
from ami.jobs.models import Job
Expand All @@ -147,9 +149,10 @@ def process_images(
source_images=[
SourceImageRequest(
id=str(source_image.pk),
url=source_image.public_url(),
url=url,
)
for source_image in images
for source_image, url in zip(images, urls)
if url
],
)

Expand Down Expand Up @@ -224,9 +227,14 @@ def save_results(results: PipelineResponse, job_id: int | None = None) -> list[m
detection_algorithm=detection_algo,
bbox=list(detection_resp.bbox.dict().values()),
).first()
# Ensure that the crop image URL is not empty or only a slash. None is fine.
if detection_resp.crop_image_url and detection_resp.crop_image_url.strip("/"):
crop_url = detection_resp.crop_image_url
else:
crop_url = None
if existing_detection:
if not existing_detection.path:
existing_detection.path = detection_resp.crop_image_url or None
existing_detection.path = crop_url
existing_detection.save()
print("Updated existing detection", existing_detection)
detection = existing_detection
Expand All @@ -235,7 +243,7 @@ def save_results(results: PipelineResponse, job_id: int | None = None) -> list[m
source_image=source_image,
bbox=list(detection_resp.bbox.dict().values()),
timestamp=source_image.timestamp,
path=detection_resp.crop_image_url or "",
path=crop_url,
detection_time=detection_resp.timestamp,
detection_algorithm=detection_algo,
)
Expand Down Expand Up @@ -276,16 +284,20 @@ def save_results(results: PipelineResponse, job_id: int | None = None) -> list[m
# or do we use the bbox as a unique identifier?
# then it doesn't matter what detection algorithm was used

new_classification = Classification()
new_classification.detection = detection
new_classification.taxon = taxon
new_classification.algorithm = classification_algo
new_classification.score = max(classification.scores)
new_classification.timestamp = now() # @TODO get timestamp from API response
# @TODO add reference to job or pipeline?
new_classification, created = Classification.objects.get_or_create(
detection=detection,
taxon=taxon,
algorithm=classification_algo,
score=max(classification.scores),
defaults={"timestamp": classification.timestamp or now()},
)

new_classification.save()
created_objects.append(new_classification)
if created:
# Optionally add reference to job or pipeline here
created_objects.append(new_classification)
else:
# Optionally handle the case where a duplicate is found
logger.warn("Duplicate classification found, not creating a new one.")

# Create a new occurrence for each detection (no tracking yet)
# @TODO remove when we implement tracking
Expand All @@ -297,7 +309,7 @@ def save_results(results: PipelineResponse, job_id: int | None = None) -> list[m
determination=taxon,
determination_score=new_classification.score,
)
detection.occurrence = occurrence
detection.occurrence = occurrence # type: ignore
detection.save()
detection.occurrence.save()

Expand All @@ -306,6 +318,12 @@ def save_results(results: PipelineResponse, job_id: int | None = None) -> list[m
for source_image in source_images:
source_image.save()

image_cropping_task = create_detection_images.delay(
source_image_ids=[source_image.pk for source_image in source_images],
)
if job:
job.logger.info(f"Creating detection images in sub-task {image_cropping_task.id}")

event_ids = [img.event_id for img in source_images]
update_calculated_fields_for_events(pks=event_ids)

Expand Down
Loading