26
26
27
27
28
28
from ultralytics import FastSAM
29
- from ultralytics .models .fastsam import FastSAMPrompt
29
+ from ultralytics .models .fastsam import FastSAMPredictor
30
30
from ultralytics .models .sam import Predictor as SAMPredictor
31
31
import fire
32
- import numpy as np
33
32
import ultralytics
34
33
35
34
from openadapt import cache
41
40
SAM_MODEL_NAMES = (
42
41
"sam_b.pt" , # base
43
42
"sam_l.pt" , # large
44
- # "mobile_sam.pt",
45
43
)
46
44
MODEL_NAMES = FASTSAM_MODEL_NAMES + SAM_MODEL_NAMES
47
45
DEFAULT_MODEL_NAME = MODEL_NAMES [0 ]
48
46
49
47
50
- # TODO: rename
51
48
def fetch_segmented_image (
52
49
image : Image .Image ,
53
50
model_name : str = DEFAULT_MODEL_NAME ,
@@ -74,14 +71,12 @@ def fetch_segmented_image(
74
71
def do_fastsam (
75
72
image : Image ,
76
73
model_name : str ,
77
- # TODO: inject from config
78
74
device : str = "cpu" ,
79
75
retina_masks : bool = True ,
80
76
imgsz : int | tuple [int , int ] | None = 1024 ,
81
- # threshold below which boxes will be filtered out
82
77
min_confidence_threshold : float = 0.4 ,
83
- # discards all overlapping boxes with IoU > iou_threshold
84
78
max_iou_threshold : float = 0.9 ,
79
+ max_det : int = 1000 ,
85
80
max_retries : int = 5 ,
86
81
retry_delay_seconds : float = 0.1 ,
87
82
) -> Image :
@@ -90,100 +85,35 @@ def do_fastsam(
90
85
For usage of thresholds see:
91
86
github.com/ultralytics/ultralytics/blob/dacbd48fcf8407098166c6812eeb751deaac0faf
92
87
/ultralytics/utils/ops.py#L164
93
-
94
- Args:
95
- TODO
96
- min_confidence_threshold (float, optional): The minimum confidence score
97
- that a detection must meet or exceed to be considered valid. Detections
98
- below this threshold will not be marked. Defaults to 0.00.
99
- max_iou_threshold (float, optional): The maximum allowed Intersection over
100
- Union (IoU) value for overlapping detections. Detections that exceed this
101
- IoU threshold are considered for suppression, keeping only the
102
- detection with the highest confidence. Defaults to 0.05.
103
88
"""
104
89
model = FastSAM (model_name )
105
-
106
90
imgsz = imgsz or image .size
107
91
108
- # Run inference on image
109
92
everything_results = model (
110
93
image ,
111
94
device = device ,
112
95
retina_masks = retina_masks ,
113
96
imgsz = imgsz ,
114
97
conf = min_confidence_threshold ,
115
98
iou = max_iou_threshold ,
99
+ max_det = max_det ,
116
100
)
117
-
118
- # Prepare a Prompt Process object
119
- prompt_process = FastSAMPrompt (image , everything_results , device = "cpu" )
120
-
121
- # Everything prompt
122
- annotations = prompt_process .everything_prompt ()
123
-
124
- # TODO: support other modes once issues are fixed
125
- # https://github.com/ultralytics/ultralytics/issues/13218#issuecomment-2142960103
126
-
127
- # Bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]
128
- # annotations = prompt_process.box_prompt(bbox=[200, 200, 300, 300])
129
-
130
- # Text prompt
131
- # annotations = prompt_process.text_prompt(text='a photo of a dog')
132
-
133
- # Point prompt
134
- # points default [[0,0]] [[x1,y1],[x2,y2]]
135
- # point_label default [0] [1,0] 0:background, 1:foreground
136
- # annotations = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1])
137
-
138
- assert len (annotations ) == 1 , len (annotations )
139
- annotation = annotations [0 ]
140
-
141
- # hide original image
142
- annotation .orig_img = np .ones (annotation .orig_img .shape )
143
-
144
- # TODO: in memory, e.g. with prompt_process.fast_show_mask()
145
- with TemporaryDirectory () as tmp_dir :
146
- # Force the output format to PNG to prevent JPEG compression artefacts
147
- annotation .path = annotation .path .replace (".jpg" , ".png" )
148
- prompt_process .plot (
149
- [annotation ],
150
- tmp_dir ,
151
- with_contours = False ,
152
- retina = False ,
101
+ assert len (everything_results ) == 1 , len (everything_results )
102
+ annotation = everything_results [0 ]
103
+
104
+ segmented_image = Image .fromarray (
105
+ annotation .plot (
106
+ img = np .ones (annotation .orig_img .shape , dtype = annotation .orig_img .dtype ),
107
+ kpt_line = False ,
108
+ labels = False ,
109
+ boxes = False ,
110
+ probs = False ,
111
+ color_mode = "instance" ,
153
112
)
154
- result_name = os .path .basename (annotation .path )
155
- logger .info (f"{ annotation .path = } " )
156
- segmented_image_path = Path (tmp_dir ) / result_name
157
- segmented_image = Image .open (segmented_image_path )
158
-
159
- # Ensure the image is fully loaded before deletion to avoid errors or incomplete operations,
160
- # as some operating systems and file systems lock files during read or processing.
161
- segmented_image .load ()
162
-
163
- # Attempt to delete the file with retries and delay
164
- retries = 0
165
-
166
- while retries < max_retries :
167
- try :
168
- os .remove (segmented_image_path )
169
- break # If deletion succeeds, exit loop
170
- except OSError as e :
171
- if e .errno == errno .ENOENT : # File not found
172
- break
173
- else :
174
- retries += 1
175
- time .sleep (retry_delay_seconds )
176
-
177
- if retries == max_retries :
178
- logger .warning (f"Failed to delete { segmented_image_path } " )
179
- # Check if the dimensions of the original and segmented images differ
180
- # XXX TODO this is a hack, this plotting code should be refactored, but the
181
- # bug may exist in ultralytics, since they seem to resize as well; see:
182
- # https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/plotting.py#L238
183
- # https://github.com/ultralytics/ultralytics/issues/561#issuecomment-1403079910
113
+ )
114
+
184
115
if image .size != segmented_image .size :
185
116
logger .warning (f"{ image .size = } != { segmented_image .size = } , resizing..." )
186
- # Resize segmented_image to match original using nearest neighbor interpolation
187
117
segmented_image = segmented_image .resize (image .size , Image .NEAREST )
188
118
189
119
assert image .size == segmented_image .size , (image .size , segmented_image .size )
@@ -194,7 +124,6 @@ def do_fastsam(
194
124
def do_sam (
195
125
image : Image .Image ,
196
126
model_name : str ,
197
- # TODO: add params
198
127
) -> Image .Image :
199
128
# Create SAMPredictor
200
129
overrides = dict (
@@ -207,20 +136,7 @@ def do_sam(
207
136
predictor = SAMPredictor (overrides = overrides )
208
137
209
138
# Segment with additional args
210
- # results = predictor(source=image, crop_n_layers=1, points_stride=64)
211
- results = predictor (
212
- source = image ,
213
- # crop_n_layers=3,
214
- # crop_overlap_ratio=0.5,
215
- # crop_downscale_factor=1,
216
- # point_grids=None,
217
- # points_stride=12,
218
- # points_batch_size=128,
219
- # conf_thres=0.8,
220
- # stability_score_thresh=0.95,
221
- # stability_score_offset=0.95,
222
- # crop_nms_thresh=0.8,
223
- )
139
+ results = predictor (source = image )
224
140
mask_ims = results_to_mask_images (results )
225
141
segmented_image = colorize_masks (mask_ims )
226
142
return segmented_image
@@ -238,8 +154,7 @@ def results_to_mask_images(
238
154
239
155
240
156
def colorize_masks (masks : list [Image .Image ]) -> Image .Image :
241
- """
242
- Takes a list of PIL images containing binary masks and returns a new PIL.Image
157
+ """Takes a list of PIL images containing binary masks and returns a new PIL.Image
243
158
where each mask is colored differently using a unique color for each mask.
244
159
245
160
Args:
@@ -249,15 +164,11 @@ def colorize_masks(masks: list[Image.Image]) -> Image.Image:
249
164
PIL.Image: A new image with each mask in a different color.
250
165
"""
251
166
if not masks :
252
- return None # Return None if the list is empty
167
+ return None
253
168
254
- # Assuming all masks are the same size, get dimensions
255
169
width , height = masks [0 ].size
256
-
257
- # Create an empty array with 3 color channels (RGB)
258
170
result_image = np .zeros ((height , width , 3 ), dtype = np .uint8 )
259
171
260
- # Generate unique colors using HSV color space
261
172
num_masks = len (masks )
262
173
colors = [
263
174
tuple (
@@ -271,17 +182,12 @@ def colorize_masks(masks: list[Image.Image]) -> Image.Image:
271
182
]
272
183
273
184
for idx , mask in enumerate (masks ):
274
- # Convert PIL Image to numpy array
275
185
mask_array = np .array (mask )
276
-
277
- # Apply the color to the mask
278
186
for c in range (3 ):
279
- # Only colorize where the mask is True (assuming mask is binary: 0 or 255)
280
187
result_image [:, :, c ] += (mask_array / 255 * colors [idx ][c ]).astype (
281
188
np .uint8
282
189
)
283
190
284
- # Convert the result back to a PIL image
285
191
return Image .fromarray (result_image )
286
192
287
193
0 commit comments