Skip to content

Commit e4b815f

Browse files
committed
Add save video prediction blended
1 parent c38d149 commit e4b815f

File tree

2 files changed

+104
-5
lines changed

2 files changed

+104
-5
lines changed

samgeo/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3239,7 +3239,7 @@ def images_to_video(
32393239
height, width, _ = frame.shape
32403240
video_size = (width, height)
32413241

3242-
fourcc = cv2.VideoWriter_fourcc(*"mp4v") # Define the codec for mp4
3242+
fourcc = cv2.VideoWriter_fourcc(*"avc1") # Define the codec for mp4
32433243
video_writer = cv2.VideoWriter(output_video, fourcc, fps, video_size)
32443244

32453245
for image_path in images:

samgeo/samgeo2.py

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import cv2
33
import torch
44
import numpy as np
5+
import matplotlib.pyplot as plt
56
from PIL.Image import Image
7+
from tqdm import tqdm
68
from typing import Any, Dict, List, Optional, Tuple, Union
79
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
810
from sam2.sam2_image_predictor import SAM2ImagePredictor
@@ -1047,6 +1049,10 @@ def _convert_prompts(self, prompts: Dict[int, Any]) -> Dict[int, Any]:
10471049
# Convert labels to np.int32 array
10481050
if "labels" in value:
10491051
value["labels"] = np.array(value["labels"], dtype=np.int32)
1052+
# Convert box to np.float32 array
1053+
if "box" in value:
1054+
value["box"] = np.array(value["box"], dtype=np.float32)
1055+
10501056
return prompts
10511057

10521058
def set_video(
@@ -1091,6 +1097,7 @@ def set_video(
10911097

10921098
self.video_path = output_dir
10931099
self._num_images = len(os.listdir(output_dir))
1100+
self._frame_names = sorted(os.listdir(output_dir))
10941101
self.inference_state = self.predictor.init_state(video_path=output_dir)
10951102

10961103
def predict_video(
@@ -1131,15 +1138,19 @@ def save_image_from_dict(data, output_path="output_image.png"):
11311138
predictor = self.predictor
11321139
inference_state = self.inference_state
11331140
for obj_id, prompt in prompts.items():
1134-
points = prompt["points"]
1135-
labels = prompt["labels"]
1136-
frame_idx = prompt["frame_idx"]
1141+
1142+
points = prompt.get("points", None)
1143+
labels = prompt.get("labels", None)
1144+
box = prompt.get("box", None)
1145+
frame_idx = prompt.get("frame_idx", None)
1146+
11371147
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
11381148
inference_state=inference_state,
11391149
frame_idx=frame_idx,
11401150
obj_id=obj_id,
11411151
points=points,
11421152
labels=labels,
1153+
box=box,
11431154
)
11441155

11451156
video_segments = {}
@@ -1202,8 +1213,96 @@ def save_image_from_dict(data, output_path="output_image.png"):
12021213
num_frames = len(self.video_segments)
12031214
num_digits = len(str(num_frames))
12041215

1205-
for frame_idx, video_segment in self.video_segments.items():
1216+
# Initialize the tqdm progress bar
1217+
for frame_idx, video_segment in tqdm(
1218+
self.video_segments.items(), desc="Rendering frames", total=num_frames
1219+
):
12061220
output_path = os.path.join(
12071221
output_dir, f"{str(frame_idx).zfill(num_digits)}.{img_ext}"
12081222
)
12091223
save_image_from_dict(video_segment, output_path)
1224+
1225+
def save_video_segments_blended(
    self,
    output_dir: str,
    img_ext: str = "png",
    dpi: int = 200,
    frame_stride: int = 1,
    output_video: Optional[str] = None,
    fps: int = 30,
) -> None:
    """Save frames blended with their segmentation masks, optionally as a video.

    Every ``frame_stride``-th frame listed in ``self._frame_names`` is read
    from ``self.video_path``, drawn with matplotlib, overlaid with the masks
    stored for that frame in ``self.video_segments``, and written to
    ``output_dir``. Frames with no entry in ``self.video_segments`` are
    written without any overlay.

    Args:
        output_dir (str): The directory to save the output images. It is
            created if it does not exist.
        img_ext (str): The file extension for the output images. Defaults to "png".
        dpi (int): The DPI (dots per inch) for the output images. Defaults to 200.
        frame_stride (int): The stride for selecting frames to save. Must be
            at least 1. Defaults to 1.
        output_video (Optional[str]): The path to the output video file. When
            given, the images in ``output_dir`` are passed to
            ``common.images_to_video``. Defaults to None.
        fps (int): The frames per second for the output video. Defaults to 30.

    Raises:
        ValueError: If ``frame_stride`` is smaller than 1.
    """
    # Validate before touching the filesystem or matplotlib: a zero stride
    # would otherwise fail late inside range(), and a negative stride would
    # silently render nothing.
    if frame_stride < 1:
        raise ValueError(
            f"frame_stride must be a positive integer, got {frame_stride}."
        )

    # Local import: the module-level name ``Image`` is the PIL Image *class*
    # (``from PIL.Image import Image``); the PIL.Image *module* is needed
    # here for ``Image.open``.
    from PIL import Image

    def show_mask(mask, ax, obj_id=None, random_color=False):
        """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay."""
        if random_color:
            color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        else:
            cmap = plt.get_cmap("tab10")
            cmap_idx = 0 if obj_id is None else obj_id
            # Integer lookups past the end of the colormap are clipped to
            # the last colour, so wrap integer IDs to keep objects distinct.
            if isinstance(cmap_idx, (int, np.integer)):
                cmap_idx = int(cmap_idx) % cmap.N
            color = np.array([*cmap(cmap_idx)[:3], 0.6])
        h, w = mask.shape[-2:]
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.imshow(mask_image)

    os.makedirs(output_dir, exist_ok=True)

    # Drop any figures left open by earlier pyplot calls.
    plt.close("all")

    video_segments = self.video_segments
    video_dir = self.video_path
    frame_names = self._frame_names
    num_digits = len(str(len(frame_names)))

    for out_frame_idx in tqdm(
        range(0, len(frame_names), frame_stride), desc="Rendering frames"
    ):
        frame_path = os.path.join(video_dir, frame_names[out_frame_idx])

        # The context manager closes the underlying file once the frame has
        # been rendered, so long videos do not exhaust file handles.
        with Image.open(frame_path) as image:
            w, h = image.size

            # Figure size follows the frame dimensions at the requested DPI.
            # NOTE(review): the 1.3 enlargement is carried over unchanged;
            # presumably it makes saved frames larger than the source frames
            # -- confirm this is intended.
            fig, ax = plt.subplots(
                figsize=(w / dpi * 1.3, h / dpi * 1.3), dpi=dpi
            )
            try:
                # Hide the axis so no ticks or frame are drawn.
                ax.axis("off")
                ax.imshow(image)

                # Frames with no stored masks (e.g. frames outside the
                # propagated range) are saved without an overlay instead of
                # raising KeyError.
                frame_masks = video_segments.get(out_frame_idx, {})
                for out_obj_id, out_mask in frame_masks.items():
                    show_mask(out_mask, ax, obj_id=out_obj_id)

                filename = f"{str(out_frame_idx).zfill(num_digits)}.{img_ext}"
                filepath = os.path.join(output_dir, filename)
                fig.savefig(filepath, dpi=dpi, pad_inches=0, bbox_inches="tight")
            finally:
                # Always release the figure, even if rendering failed.
                plt.close(fig)

    if output_video is not None:
        common.images_to_video(output_dir, output_video, fps=fps)

0 commit comments

Comments
 (0)