Skip to content

Commit 409953a

Browse files
Add detection_filter to predict() method to allow for user-defined logic (#307)
* Added a `detection_filter` parameter to `predict` method to allow for filtering of detections based on user-defined logic; appears to seamlessly work with `batch_predict` also; * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9894762 commit 409953a

File tree

1 file changed

+31
-2
lines changed

1 file changed

+31
-2
lines changed

samgeo/text_sam.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
Credits to Luca Medeiros for the original implementation.
44
"""
55

6+
import argparse
7+
import inspect
68
import os
79
import warnings
8-
import argparse
10+
911
import numpy as np
1012
import torch
1113
from PIL import Image
@@ -238,6 +240,7 @@ def predict(
238240
save_args={},
239241
return_results=False,
240242
return_coords=False,
243+
detection_filter=None,
241244
**kwargs,
242245
):
243246
"""
@@ -253,6 +256,10 @@ def predict(
253256
dtype (np.dtype, optional): Data type for the prediction. Defaults to np.uint8.
254257
save_args (dict, optional): Save arguments for the prediction. Defaults to {}.
255258
return_results (bool, optional): Whether to return the results. Defaults to False.
259+
detection_filter (callable, optional):
260+
Callable which with box, mask, logit, phrase, and index args returns a boolean.
261+
If provided, the function will be called for each detected object.
262+
Defaults to None.
256263
257264
Returns:
258265
tuple: Tuple containing masks, boxes, phrases, and logits.
@@ -312,12 +319,34 @@ def predict(
312319
image_np[..., 0], dtype=dtype
313320
) # Adjusted for single channel
314321

315-
for i, (box, mask) in enumerate(zip(boxes, masks)):
322+
# Validate the detection_filter argument
323+
if detection_filter is not None:
324+
325+
if not callable(detection_filter):
326+
raise ValueError("detection_filter must be callable.")
327+
328+
req_nargs = 6 if inspect.ismethod(detection_filter) else 5
329+
if not len(inspect.signature(detection_filter).parameters) == req_nargs:
330+
raise ValueError(
331+
"detection_filter required args: "
332+
"box, mask, logit, phrase, and index."
333+
)
334+
335+
for i, (box, mask, logit, phrase) in enumerate(
336+
zip(boxes, masks, logits, phrases)
337+
):
338+
316339
# Convert tensor to numpy array if necessary and ensure it contains integers
317340
if isinstance(mask, torch.Tensor):
318341
mask = (
319342
mask.cpu().numpy().astype(dtype)
320343
) # If mask is on GPU, use .cpu() before .numpy()
344+
345+
# Apply the user-supplied filtering logic if provided
346+
if detection_filter is not None:
347+
if not detection_filter(box, mask, logit, phrase, i):
348+
continue
349+
321350
mask_overlay += ((mask > 0) * (i + 1)).astype(
322351
dtype
323352
) # Assign a unique value for each mask

0 commit comments

Comments
 (0)