Skip to content

Commit e0834ad

Browse files
committed
feat: add filters
1 parent 8ed27ee commit e0834ad

File tree

4 files changed

+147
-8
lines changed

4 files changed

+147
-8
lines changed

ami/main/models.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2031,6 +2031,37 @@ def top_n(self, n: int = 3) -> list[dict[str, "Taxon | float | None"]]:
20312031
for i, s in top_scored
20322032
]
20332033

2034+
def genus_scores_with_taxa(self) -> typing.Iterable[tuple[str, float]]:
2035+
"""
2036+
Return the genus scores for this classification using the category map.
2037+
"""
2038+
raise NotImplementedError
2039+
predictions = self.predictions_with_taxa()
2040+
genus_scores = {}
2041+
for taxon, score in predictions:
2042+
genus = taxon.get_parent(rank="GENUS")
2043+
if genus:
2044+
genus_scores[genus] = genus_scores.get(genus, 0) + score
2045+
return sorted(genus_scores.items(), key=lambda x: x[1], reverse=True)
2046+
2047+
def genus_scores_by_splitting_names(self) -> typing.Iterable[tuple["Taxon", float]]:
2048+
"""
2049+
Return the genus scores for this classification using the category map.
2050+
"""
2051+
predictions = self.predictions()
2052+
genus_scores = {}
2053+
for taxon_name, score in predictions:
2054+
genus_name = taxon_name.split(" ")[0]
2055+
if genus_name:
2056+
genus_scores[genus_name] = genus_scores.get(genus_name, 0) + score
2057+
2058+
# Get or make actual Taxon objects
2059+
genus_scores = {
2060+
Taxon.objects.get_or_create(name=genus_name, rank="GENUS")[0]: score
2061+
for genus_name, score in genus_scores.items()
2062+
}
2063+
return sorted(genus_scores.items(), key=lambda x: x[1], reverse=True)
2064+
20342065
def get_similar_classifications(self, distance_metric="cosine") -> models.QuerySet:
20352066
"""
20362067
Return most similar classifications based on feature_2048 embeddings.

ami/ml/clustering_algorithms/agglomerative.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,47 @@ def setup(self, data_dict):
7070
data_dict["val"]["feat_list"], data_dict["val"]["label_list"]
7171
)
7272

73+
def cluster_by_higher_taxon(self, features, rel_sizes, predictions, taxon_map, taxon_rank)-> tuple[np.ndarray, np.ndarray]:
74+
75+
76+
# map labdls -> [species, genus, family]
77+
78+
# [1, 0, 4, 5, ...]
79+
# [sp.1, sp.2, ...]
80+
# [genus1, ]
81+
# TODO: create taxon_mask based on predictions and taxon_map
82+
taxons = taxon_map[predictions]
83+
84+
cluster_id_offset = 0
85+
86+
all_cluster_ids = []
87+
all_silhouette_scores = []
88+
89+
for taxon in taxons:
90+
taxon_features = features[predictions.isin(taxon)] #TODO: change this
91+
taxon_rel_sizes = rel_sizes[predictions.isin(taxon)] #TODO: change this
92+
93+
cluster_ids, silhouette_scores = self.cluster(taxon_features, taxon_rel_sizes)
94+
cluster_ids += cluster_id_offset
95+
96+
cluster_id_offset += len(np.unique(cluster_ids))
97+
98+
all_cluster_ids.append(cluster_ids)
99+
all_silhouette_scores.append(silhouette_scores)
100+
101+
all_cluster_ids = np.concatenate(all_cluster_ids, axis=0)
102+
all_silhouette_scores = np.concatenate(all_silhouette_scores, axis=0)
103+
104+
return all_cluster_ids, all_silhouette_scores
105+
106+
107+
108+
109+
110+
111+
112+
113+
73114
def cluster(self, features, rel_sizes) -> tuple[np.ndarray, np.ndarray]:
74115
logger.info(f"distance threshold: {self.distance_threshold}")
75116
logger.info("features shape: %s", features.shape)

ami/ml/clustering_algorithms/cluster_detections.py

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
import logging
33
import typing
44

5+
# import cv2
56
import numpy as np
67
from django.db.models import Count
78
from django.utils.timezone import now
9+
from PIL import Image
810

911
from ami.ml.clustering_algorithms.utils import get_clusterer
1012

@@ -85,6 +87,58 @@ def get_cluster_name(cluster_id: int, taxon: "Taxon | None" = None, job: "Job |
8587

8688
return " ".join(part for part in parts if part)
8789

90+
def remove_detection_on_edge(detection):
91+
92+
bbox = detection.bbox
93+
img_width, img_height = detection.source_image.width, detection.source_image.height
94+
95+
# left
96+
if bbox[0] < 1:
97+
return True
98+
99+
# top
100+
if bbox[1] < 1:
101+
return True
102+
103+
# right
104+
if bbox[2] > img_width - 2:
105+
return True
106+
107+
108+
if bbox[3] > img_height -2 :
109+
return True
110+
111+
return False
112+
113+
114+
def get_relative_size(detection):
115+
bbox_width, bbox_height = detection.width(), detection.height()
116+
img_width, img_height = detection.source_image.width, detection.source_image.height
117+
detection.source_image.deployment
118+
assert img_width and img_height
119+
relative_size = (bbox_width * bbox_height) / (img_width * img_height)
120+
return relative_size
121+
122+
def compute_sharpness(image_path):
123+
image = Image.open(image_path).convert('L')
124+
image_array = np.array(image, dtype=np.float32)
125+
126+
# Define Laplacian kernel
127+
kernel = np.array([[0, 1, 0],
128+
[1, -4, 1],
129+
[0, 1, 0]], dtype=np.float32)
130+
131+
padded = np.pad(image_array, pad_width=1, mode='reflect')
132+
laplacian = np.zeros_like(image_array)
133+
134+
for i in range(image_array.shape[0]):
135+
for j in range(image_array.shape[1]):
136+
region = padded[i:i+3, j:j+3]
137+
laplacian[i, j] = np.sum(region * kernel)
138+
139+
laplacian_std = np.std(laplacian)
140+
141+
return laplacian_std
88142

89143
def cluster_detections(
90144
collection, params: dict, task_logger: logging.Logger = logger, job=None
@@ -94,6 +148,9 @@ def cluster_detections(
94148
from ami.ml.models import Algorithm
95149
from ami.ml.models.pipeline import create_and_update_occurrences_for_detections
96150

151+
sharpness_threshold = 8
152+
relative_size_threshold = 0.1 # TODO: this should be updated
153+
97154
ood_threshold = params.get("ood_threshold", 1)
98155
feature_extraction_algorithm = params.get("feature_extraction_algorithm", None)
99156
algorithm = params.get("clustering_algorithm", "agglomerative")
@@ -123,18 +180,28 @@ def cluster_detections(
123180
valid_classifications = []
124181
update_job_progress(job, stage_key="feature_collection", status=JobState.STARTED, progress=0.0)
125182
# Collecting features for detections
183+
184+
126185
for idx, detection in enumerate(detections):
186+
127187
classification = detection.classifications.filter(
128188
features_2048__isnull=False,
129189
algorithm=feature_extraction_algorithm,
130190
).first()
191+
131192
if classification:
193+
if remove_detection_on_edge(detection): # remove crops that are on the edge
194+
continue
195+
relative_size = get_relative_size(detection)
196+
197+
if relative_size < relative_size_threshold: # remove small crops
198+
continue
199+
200+
sharpness = compute_sharpness(detection.path) # remove blurry images
201+
if sharpness < sharpness_threshold:
202+
continue
203+
132204
features.append(classification.features_2048)
133-
bbox_width, bbox_height = detection.width(), detection.height()
134-
img_width, img_height = detection.source_image.width, detection.source_image.height
135-
detection.source_image.deployment
136-
assert img_width and img_height
137-
relative_size = (bbox_width * bbox_height) / (img_width * img_height)
138205
sizes.append(relative_size)
139206
valid_detections.append(detection)
140207
valid_classifications.append(classification)
@@ -160,7 +227,7 @@ def cluster_detections(
160227
if not ClusteringAlgorithm:
161228
raise ValueError(f"Unsupported clustering algorithm: {algorithm}")
162229

163-
cluster_ids, cluster_scores = ClusteringAlgorithm(params).cluster(features_np, size_np)
230+
cluster_ids, cluster_scores = ClusteringAlgorithm(params).cluster(features_np, size_np) # TODO: change this
164231

165232
task_logger.info(f"Clustering completed with {len(set(cluster_ids))} clusters")
166233
clusters: dict[int, list[ClusterMember]] = {}

ami/ml/tests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
create_captures,
2525
create_captures_from_files,
2626
create_detections,
27-
create_processing_service,
2827
create_taxa,
2928
group_images_into_events,
3029
setup_test_project,
@@ -732,6 +731,7 @@ def populate_collection_with_detections(self):
732731
def _populate_detection_features(self):
733732
"""Populate detection features with random values."""
734733
classifier = Algorithm.objects.get(key="random-species-classifier")
734+
taxon = Taxon.objects.last()
735735
for detection in self.detections:
736736
detection.associate_new_occurrence()
737737
# Create a random feature vector
@@ -740,7 +740,7 @@ def _populate_detection_features(self):
740740
classification = Classification.objects.create(
741741
detection=detection,
742742
algorithm=classifier,
743-
taxon=None,
743+
taxon=taxon,
744744
score=0.5,
745745
ood_score=0.5,
746746
features_2048=feature_vector,

0 commit comments

Comments
 (0)