2
2
import logging
3
3
import typing
4
4
5
+ # import cv2
5
6
import numpy as np
6
7
from django .db .models import Count
7
8
from django .utils .timezone import now
9
+ from PIL import Image
8
10
9
11
from ami .ml .clustering_algorithms .utils import get_clusterer
10
12
@@ -85,6 +87,58 @@ def get_cluster_name(cluster_id: int, taxon: "Taxon | None" = None, job: "Job |
85
87
86
88
return " " .join (part for part in parts if part )
87
89
90
def remove_detection_on_edge(detection, margin: int = 1) -> bool:
    """Return True if the detection's bounding box touches (or nearly
    touches) the border of its source image.

    Crops on the image edge are typically truncated objects and are
    unsuitable for feature clustering, so callers use this to skip them.

    Args:
        detection: Object with a ``bbox`` sequence ``(x1, y1, x2, y2)``
            in pixel coordinates and a ``source_image`` exposing
            ``width`` and ``height`` attributes.
        margin: Distance in pixels from the image border within which a
            bbox side counts as "on the edge". Defaults to 1, matching
            the previous hard-coded thresholds (``< 1`` / ``> dim - 2``).

    Returns:
        bool: True when any side of the bbox lies within ``margin``
        pixels of the corresponding image border, else False.
    """
    x1, y1, x2, y2 = detection.bbox[:4]
    img_width = detection.source_image.width
    img_height = detection.source_image.height

    return (
        x1 < margin  # left edge
        or y1 < margin  # top edge
        or x2 > img_width - 1 - margin  # right edge
        or y2 > img_height - 1 - margin  # bottom edge
    )
112
+
113
+
114
def get_relative_size(detection) -> float:
    """Return the detection's bounding-box area as a fraction of its
    source image's area.

    Args:
        detection: Object with ``width()`` / ``height()`` methods giving
            the bbox dimensions in pixels, and a ``source_image``
            exposing ``width`` and ``height`` attributes.

    Returns:
        float: ``bbox_area / image_area`` — in ``(0, 1]`` for boxes that
        fit inside the image.

    Raises:
        AssertionError: If the source image dimensions are missing or zero.
    """
    bbox_width, bbox_height = detection.width(), detection.height()
    img_width = detection.source_image.width
    img_height = detection.source_image.height
    # NOTE(review): removed a stray no-op `detection.source_image.deployment`
    # attribute access here — its result was never used.
    assert img_width and img_height
    return (bbox_width * bbox_height) / (img_width * img_height)
121
+
122
def compute_sharpness(image_path):
    """Estimate image sharpness as the standard deviation of the image's
    Laplacian response.

    Blurry images have weak edges, so their Laplacian response has low
    spread; a small return value indicates a blurry crop.

    Args:
        image_path: Path to an image file (opened with PIL and converted
            to grayscale), or a pre-loaded 2D numpy array of grayscale
            pixel values.

    Returns:
        Standard deviation (numpy scalar float) of the Laplacian of the
        grayscale image.
    """
    if isinstance(image_path, np.ndarray):
        image_array = image_path.astype(np.float32)
    else:
        image = Image.open(image_path).convert('L')
        image_array = np.array(image, dtype=np.float32)

    # Reflect-pad by one pixel so the 3x3 Laplacian kernel
    # [[0, 1, 0], [1, -4, 1], [0, 1, 0]] is defined at the borders.
    padded = np.pad(image_array, pad_width=1, mode='reflect')

    # Vectorized convolution: sum of the 4-neighbourhood minus 4x the
    # centre pixel. Mathematically identical to sliding the kernel with
    # a per-pixel Python loop, but runs in C over whole arrays.
    laplacian = (
        padded[:-2, 1:-1]      # pixel above
        + padded[2:, 1:-1]     # pixel below
        + padded[1:-1, :-2]    # pixel left
        + padded[1:-1, 2:]     # pixel right
        - 4.0 * padded[1:-1, 1:-1]
    )

    return np.std(laplacian)
88
142
89
143
def cluster_detections (
90
144
collection , params : dict , task_logger : logging .Logger = logger , job = None
@@ -94,6 +148,9 @@ def cluster_detections(
94
148
from ami .ml .models import Algorithm
95
149
from ami .ml .models .pipeline import create_and_update_occurrences_for_detections
96
150
151
+ sharpness_threshold = 8
152
+ relative_size_threshold = 0.1 # TODO: this should be updated
153
+
97
154
ood_threshold = params .get ("ood_threshold" , 1 )
98
155
feature_extraction_algorithm = params .get ("feature_extraction_algorithm" , None )
99
156
algorithm = params .get ("clustering_algorithm" , "agglomerative" )
@@ -123,18 +180,28 @@ def cluster_detections(
123
180
valid_classifications = []
124
181
update_job_progress (job , stage_key = "feature_collection" , status = JobState .STARTED , progress = 0.0 )
125
182
# Collecting features for detections
183
+
184
+
126
185
for idx , detection in enumerate (detections ):
186
+
127
187
classification = detection .classifications .filter (
128
188
features_2048__isnull = False ,
129
189
algorithm = feature_extraction_algorithm ,
130
190
).first ()
191
+
131
192
if classification :
193
+ if remove_detection_on_edge (detection ): # remove crops that are on the edge
194
+ continue
195
+ relative_size = get_relative_size (detection )
196
+
197
+ if relative_size < relative_size_threshold : # remove small crops
198
+ continue
199
+
200
+ sharpness = compute_sharpness (detection .path ) # remove blurry images
201
+ if sharpness < sharpness_threshold :
202
+ continue
203
+
132
204
features .append (classification .features_2048 )
133
- bbox_width , bbox_height = detection .width (), detection .height ()
134
- img_width , img_height = detection .source_image .width , detection .source_image .height
135
- detection .source_image .deployment
136
- assert img_width and img_height
137
- relative_size = (bbox_width * bbox_height ) / (img_width * img_height )
138
205
sizes .append (relative_size )
139
206
valid_detections .append (detection )
140
207
valid_classifications .append (classification )
@@ -160,7 +227,7 @@ def cluster_detections(
160
227
if not ClusteringAlgorithm :
161
228
raise ValueError (f"Unsupported clustering algorithm: { algorithm } " )
162
229
163
- cluster_ids , cluster_scores = ClusteringAlgorithm (params ).cluster (features_np , size_np )
230
+ cluster_ids , cluster_scores = ClusteringAlgorithm (params ).cluster (features_np , size_np ) # TODO: change this
164
231
165
232
task_logger .info (f"Clustering completed with { len (set (cluster_ids ))} clusters" )
166
233
clusters : dict [int , list [ClusterMember ]] = {}
0 commit comments