Skip to content

Fix/replay #900

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 19 additions & 113 deletions openadapt/adapters/ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@


from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt
from ultralytics.models.fastsam import FastSAMPredictor
from ultralytics.models.sam import Predictor as SAMPredictor
import fire
import numpy as np
import ultralytics

from openadapt import cache
Expand All @@ -41,13 +40,11 @@
SAM_MODEL_NAMES = (
"sam_b.pt", # base
"sam_l.pt", # large
# "mobile_sam.pt",
)
MODEL_NAMES = FASTSAM_MODEL_NAMES + SAM_MODEL_NAMES
DEFAULT_MODEL_NAME = MODEL_NAMES[0]


# TODO: rename
def fetch_segmented_image(
image: Image.Image,
model_name: str = DEFAULT_MODEL_NAME,
Expand All @@ -74,14 +71,12 @@ def fetch_segmented_image(
def do_fastsam(
image: Image,
model_name: str,
# TODO: inject from config
device: str = "cpu",
retina_masks: bool = True,
imgsz: int | tuple[int, int] | None = 1024,
# threshold below which boxes will be filtered out
min_confidence_threshold: float = 0.4,
# discards all overlapping boxes with IoU > iou_threshold
max_iou_threshold: float = 0.9,
max_det: int = 1000,
max_retries: int = 5,
retry_delay_seconds: float = 0.1,
) -> Image:
Expand All @@ -90,100 +85,35 @@ def do_fastsam(
For usage of thresholds see:
github.com/ultralytics/ultralytics/blob/dacbd48fcf8407098166c6812eeb751deaac0faf
/ultralytics/utils/ops.py#L164

Args:
TODO
min_confidence_threshold (float, optional): The minimum confidence score
that a detection must meet or exceed to be considered valid. Detections
below this threshold will not be marked. Defaults to 0.4.
max_iou_threshold (float, optional): The maximum allowed Intersection over
Union (IoU) value for overlapping detections. Detections that exceed this
IoU threshold are considered for suppression, keeping only the
detection with the highest confidence. Defaults to 0.9.
"""
model = FastSAM(model_name)

imgsz = imgsz or image.size

# Run inference on image
everything_results = model(
image,
device=device,
retina_masks=retina_masks,
imgsz=imgsz,
conf=min_confidence_threshold,
iou=max_iou_threshold,
max_det=max_det,
)

# Prepare a Prompt Process object
prompt_process = FastSAMPrompt(image, everything_results, device="cpu")

# Everything prompt
annotations = prompt_process.everything_prompt()

# TODO: support other modes once issues are fixed
# https://github.com/ultralytics/ultralytics/issues/13218#issuecomment-2142960103

# Bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]
# annotations = prompt_process.box_prompt(bbox=[200, 200, 300, 300])

# Text prompt
# annotations = prompt_process.text_prompt(text='a photo of a dog')

# Point prompt
# points default [[0,0]] [[x1,y1],[x2,y2]]
# point_label default [0] [1,0] 0:background, 1:foreground
# annotations = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1])

assert len(annotations) == 1, len(annotations)
annotation = annotations[0]

# hide original image
annotation.orig_img = np.ones(annotation.orig_img.shape)

# TODO: in memory, e.g. with prompt_process.fast_show_mask()
with TemporaryDirectory() as tmp_dir:
# Force the output format to PNG to prevent JPEG compression artefacts
annotation.path = annotation.path.replace(".jpg", ".png")
prompt_process.plot(
[annotation],
tmp_dir,
with_contours=False,
retina=False,
assert len(everything_results) == 1, len(everything_results)
annotation = everything_results[0]

segmented_image = Image.fromarray(
annotation.plot(
img=np.ones(annotation.orig_img.shape, dtype=annotation.orig_img.dtype),
kpt_line=False,
labels=False,
boxes=False,
probs=False,
color_mode="instance",
)
result_name = os.path.basename(annotation.path)
logger.info(f"{annotation.path=}")
segmented_image_path = Path(tmp_dir) / result_name
segmented_image = Image.open(segmented_image_path)

# Ensure the image is fully loaded before deletion to avoid errors or incomplete operations,
# as some operating systems and file systems lock files during read or processing.
segmented_image.load()

# Attempt to delete the file with retries and delay
retries = 0

while retries < max_retries:
try:
os.remove(segmented_image_path)
break # If deletion succeeds, exit loop
except OSError as e:
if e.errno == errno.ENOENT: # File not found
break
else:
retries += 1
time.sleep(retry_delay_seconds)

if retries == max_retries:
logger.warning(f"Failed to delete {segmented_image_path}")
# Check if the dimensions of the original and segmented images differ
# XXX TODO this is a hack, this plotting code should be refactored, but the
# bug may exist in ultralytics, since they seem to resize as well; see:
# https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/plotting.py#L238
# https://github.com/ultralytics/ultralytics/issues/561#issuecomment-1403079910
)

if image.size != segmented_image.size:
logger.warning(f"{image.size=} != {segmented_image.size=}, resizing...")
# Resize segmented_image to match original using nearest neighbor interpolation
segmented_image = segmented_image.resize(image.size, Image.NEAREST)

assert image.size == segmented_image.size, (image.size, segmented_image.size)
Expand All @@ -194,7 +124,6 @@ def do_fastsam(
def do_sam(
image: Image.Image,
model_name: str,
# TODO: add params
) -> Image.Image:
# Create SAMPredictor
overrides = dict(
Expand All @@ -207,20 +136,7 @@ def do_sam(
predictor = SAMPredictor(overrides=overrides)

# Segment with additional args
# results = predictor(source=image, crop_n_layers=1, points_stride=64)
results = predictor(
source=image,
# crop_n_layers=3,
# crop_overlap_ratio=0.5,
# crop_downscale_factor=1,
# point_grids=None,
# points_stride=12,
# points_batch_size=128,
# conf_thres=0.8,
# stability_score_thresh=0.95,
# stability_score_offset=0.95,
# crop_nms_thresh=0.8,
)
results = predictor(source=image)
mask_ims = results_to_mask_images(results)
segmented_image = colorize_masks(mask_ims)
return segmented_image
Expand All @@ -238,8 +154,7 @@ def results_to_mask_images(


def colorize_masks(masks: list[Image.Image]) -> Image.Image:
"""
Takes a list of PIL images containing binary masks and returns a new PIL.Image
"""Takes a list of PIL images containing binary masks and returns a new PIL.Image
where each mask is colored differently using a unique color for each mask.

Args:
Expand All @@ -249,15 +164,11 @@ def colorize_masks(masks: list[Image.Image]) -> Image.Image:
PIL.Image: A new image with each mask in a different color.
"""
if not masks:
return None # Return None if the list is empty
return None

# Assuming all masks are the same size, get dimensions
width, height = masks[0].size

# Create an empty array with 3 color channels (RGB)
result_image = np.zeros((height, width, 3), dtype=np.uint8)

# Generate unique colors using HSV color space
num_masks = len(masks)
colors = [
tuple(
Expand All @@ -271,17 +182,12 @@ def colorize_masks(masks: list[Image.Image]) -> Image.Image:
]

for idx, mask in enumerate(masks):
# Convert PIL Image to numpy array
mask_array = np.array(mask)

# Apply the color to the mask
for c in range(3):
# Only colorize where the mask is True (assuming mask is binary: 0 or 255)
result_image[:, :, c] += (mask_array / 255 * colors[idx][c]).astype(
np.uint8
)

# Convert the result back to a PIL image
return Image.fromarray(result_image)


Expand Down
3 changes: 2 additions & 1 deletion openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ class SegmentationAdapter(str, Enum):
# Error reporting
ERROR_REPORTING_ENABLED: bool = True
ERROR_REPORTING_DSN: ClassVar = (
"https://dcf5d7889a3b4b47ae12a3af9ffcbeb7@app.glitchtip.com/3798"
# "https://dcf5d7889a3b4b47ae12a3af9ffcbeb7@app.glitchtip.com/3798"
"https://5d24fc5a2e674ea6b42275e5702499ce@app.glitchtip.com/8771",
)
ERROR_REPORTING_BRANCH: ClassVar = "main"

Expand Down
Loading
Loading