Skip to content

Commit db8a155

Browse files
authored
Add support for filtering objets by size (#333)
1 parent 837f5a5 commit db8a155

File tree

2 files changed

+71
-2
lines changed

2 files changed

+71
-2
lines changed

samgeo/samgeo.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ def generate(
158158
erosion_kernel=None,
159159
mask_multiplier=255,
160160
unique=True,
161+
min_size=0,
162+
max_size=None,
161163
**kwargs,
162164
):
163165
"""Generate masks for the input image.
@@ -180,6 +182,9 @@ def generate(
180182
The parameter is ignored if unique is True.
181183
unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
182184
The unique value increases from 1 to the number of objects. The larger the number, the larger the object area.
185+
min_size (int, optional): The minimum size of the objects. Defaults to 0.
186+
max_size (int, optional): The maximum size of the objects. Defaults to None.
187+
**kwargs: Other arguments for save_masks().
183188
184189
"""
185190

@@ -221,10 +226,19 @@ def generate(
221226
masks = mask_generator.generate(image) # Segment the input image
222227
self.masks = masks # Store the masks as a list of dictionaries
223228
self.batch = False
229+
self._min_size = min_size
230+
self._max_size = max_size
224231

225232
# Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
226233
self.save_masks(
227-
output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs
234+
output,
235+
foreground,
236+
unique,
237+
erosion_kernel,
238+
mask_multiplier,
239+
min_size,
240+
max_size,
241+
**kwargs,
228242
)
229243

230244
def save_masks(
@@ -234,6 +248,8 @@ def save_masks(
234248
unique=True,
235249
erosion_kernel=None,
236250
mask_multiplier=255,
251+
min_size=0,
252+
max_size=None,
237253
**kwargs,
238254
):
239255
"""Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
@@ -246,6 +262,9 @@ def save_masks(
246262
Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
247263
mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
248264
You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
265+
min_size (int, optional): The minimum size of the objects. Defaults to 0.
266+
max_size (int, optional): The maximum size of the objects. Defaults to None.
267+
**kwargs: Other arguments for array_to_image().
249268
250269
"""
251270

@@ -279,6 +298,10 @@ def save_masks(
279298
count = len(sorted_masks)
280299
for index, ann in enumerate(sorted_masks):
281300
m = ann["segmentation"]
301+
if min_size > 0 and ann["area"] < min_size:
302+
continue
303+
if max_size is not None and ann["area"] > max_size:
304+
continue
282305
objects[m] = count - index
283306

284307
# Generate a binary mask
@@ -290,6 +313,10 @@ def save_masks(
290313
resulting_borders = np.zeros((h, w), dtype=dtype)
291314

292315
for m in masks:
316+
if min_size > 0 and m["area"] < min_size:
317+
continue
318+
if max_size is not None and m["area"] > max_size:
319+
continue
293320
mask = (m["segmentation"] > 0).astype(dtype)
294321
resulting_mask += mask
295322

@@ -384,6 +411,14 @@ def show_anns(
384411
)
385412
img[:, :, 3] = 0
386413
for ann in sorted_anns:
414+
if hasattr(self, "_min_size") and (ann["area"] < self._min_size):
415+
continue
416+
if (
417+
hasattr(self, "_max_size")
418+
and isinstance(self._max_size, int)
419+
and ann["area"] > self._max_size
420+
):
421+
continue
387422
m = ann["segmentation"]
388423
color_mask = np.concatenate([np.random.random(3), [alpha]])
389424
img[m] = color_mask

samgeo/samgeo2.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ def generate(
192192
erosion_kernel: Optional[Tuple[int, int]] = None,
193193
mask_multiplier: int = 255,
194194
unique: bool = True,
195+
min_size: int = 0,
196+
max_size: int = None,
195197
**kwargs: Any,
196198
) -> List[Dict[str, Any]]:
197199
"""
@@ -215,6 +217,8 @@ def generate(
215217
Defaults to True.
216218
The unique value increases from 1 to the number of objects. The
217219
larger the number, the larger the object area.
220+
min_size (int): The minimum size of the object. Defaults to 0.
221+
max_size (int): The maximum size of the object. Defaults to None.
218222
**kwargs (Any): Additional keyword arguments.
219223
220224
Returns:
@@ -241,11 +245,20 @@ def generate(
241245
mask_generator = self.mask_generator # The automatic mask generator
242246
masks = mask_generator.generate(image) # Segment the input image
243247
self.masks = masks # Store the masks as a list of dictionaries
248+
self._min_size = min_size
249+
self._max_size = max_size
244250

245251
if output is not None:
246252
# Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
247253
self.save_masks(
248-
output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs
254+
output,
255+
foreground,
256+
unique,
257+
erosion_kernel,
258+
mask_multiplier,
259+
min_size,
260+
max_size,
261+
**kwargs,
249262
)
250263

251264
def save_masks(
@@ -255,6 +268,8 @@ def save_masks(
255268
unique: bool = True,
256269
erosion_kernel: Optional[Tuple[int, int]] = None,
257270
mask_multiplier: int = 255,
271+
min_size: int = 0,
272+
max_size: int = None,
258273
**kwargs: Any,
259274
) -> None:
260275
"""Save the masks to the output path. The output is either a binary mask
@@ -275,6 +290,9 @@ def save_masks(
275290
mask, which is usually a binary mask [0, 1]. You can use this
276291
parameter to scale the mask to a larger range, for example
277292
[0, 255]. Defaults to 255.
293+
min_size (int, optional): The minimum size of the object. Defaults to 0.
294+
max_size (int, optional): The maximum size of the object. Defaults to None.
295+
**kwargs: Additional keyword arguments for common.array_to_image().
278296
"""
279297

280298
if self.masks is None:
@@ -307,6 +325,10 @@ def save_masks(
307325
count = len(sorted_masks)
308326
for index, ann in enumerate(sorted_masks):
309327
m = ann["segmentation"]
328+
if min_size > 0 and ann["area"] < min_size:
329+
continue
330+
if max_size is not None and ann["area"] > max_size:
331+
continue
310332
objects[m] = count - index
311333

312334
# Generate a binary mask
@@ -318,6 +340,10 @@ def save_masks(
318340
resulting_borders = np.zeros((h, w), dtype=dtype)
319341

320342
for m in masks:
343+
if min_size > 0 and m["area"] < min_size:
344+
continue
345+
if max_size is not None and m["area"] > max_size:
346+
continue
321347
mask = (m["segmentation"] > 0).astype(dtype)
322348
resulting_mask += mask
323349

@@ -415,6 +441,14 @@ def show_anns(
415441
)
416442
img[:, :, 3] = 0
417443
for ann in sorted_anns:
444+
if hasattr(self, "_min_size") and (ann["area"] < self._min_size):
445+
continue
446+
if (
447+
hasattr(self, "_max_size")
448+
and isinstance(self._max_size, int)
449+
and ann["area"] > self._max_size
450+
):
451+
continue
418452
m = ann["segmentation"]
419453
color_mask = np.concatenate([np.random.random(3), [alpha]])
420454
img[m] = color_mask

0 commit comments

Comments
 (0)