|
2 | 2 | import cv2
|
3 | 3 | import torch
|
4 | 4 | import numpy as np
|
| 5 | +import matplotlib.pyplot as plt |
5 | 6 | from PIL.Image import Image
|
| 7 | +from tqdm import tqdm |
6 | 8 | from typing import Any, Dict, List, Optional, Tuple, Union
|
7 | 9 | from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
8 | 10 | from sam2.sam2_image_predictor import SAM2ImagePredictor
|
@@ -1047,6 +1049,10 @@ def _convert_prompts(self, prompts: Dict[int, Any]) -> Dict[int, Any]:
|
1047 | 1049 | # Convert labels to np.int32 array
|
1048 | 1050 | if "labels" in value:
|
1049 | 1051 | value["labels"] = np.array(value["labels"], dtype=np.int32)
|
| 1052 | + # Convert box to np.float32 array |
| 1053 | + if "box" in value: |
| 1054 | + value["box"] = np.array(value["box"], dtype=np.float32) |
| 1055 | + |
1050 | 1056 | return prompts
|
1051 | 1057 |
|
1052 | 1058 | def set_video(
|
@@ -1091,6 +1097,7 @@ def set_video(
|
1091 | 1097 |
|
1092 | 1098 | self.video_path = output_dir
|
1093 | 1099 | self._num_images = len(os.listdir(output_dir))
|
| 1100 | + self._frame_names = sorted(os.listdir(output_dir)) |
1094 | 1101 | self.inference_state = self.predictor.init_state(video_path=output_dir)
|
1095 | 1102 |
|
1096 | 1103 | def predict_video(
|
@@ -1131,15 +1138,19 @@ def save_image_from_dict(data, output_path="output_image.png"):
|
1131 | 1138 | predictor = self.predictor
|
1132 | 1139 | inference_state = self.inference_state
|
1133 | 1140 | for obj_id, prompt in prompts.items():
|
1134 |
| - points = prompt["points"] |
1135 |
| - labels = prompt["labels"] |
1136 |
| - frame_idx = prompt["frame_idx"] |
| 1141 | + |
| 1142 | + points = prompt.get("points", None) |
| 1143 | + labels = prompt.get("labels", None) |
| 1144 | + box = prompt.get("box", None) |
| 1145 | + frame_idx = prompt.get("frame_idx", None) |
| 1146 | + |
1137 | 1147 | _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
|
1138 | 1148 | inference_state=inference_state,
|
1139 | 1149 | frame_idx=frame_idx,
|
1140 | 1150 | obj_id=obj_id,
|
1141 | 1151 | points=points,
|
1142 | 1152 | labels=labels,
|
| 1153 | + box=box, |
1143 | 1154 | )
|
1144 | 1155 |
|
1145 | 1156 | video_segments = {}
|
@@ -1202,8 +1213,96 @@ def save_image_from_dict(data, output_path="output_image.png"):
|
1202 | 1213 | num_frames = len(self.video_segments)
|
1203 | 1214 | num_digits = len(str(num_frames))
|
1204 | 1215 |
|
1205 |
| - for frame_idx, video_segment in self.video_segments.items(): |
| 1216 | + # Initialize the tqdm progress bar |
| 1217 | + for frame_idx, video_segment in tqdm( |
| 1218 | + self.video_segments.items(), desc="Rendering frames", total=num_frames |
| 1219 | + ): |
1206 | 1220 | output_path = os.path.join(
|
1207 | 1221 | output_dir, f"{str(frame_idx).zfill(num_digits)}.{img_ext}"
|
1208 | 1222 | )
|
1209 | 1223 | save_image_from_dict(video_segment, output_path)
|
| 1224 | + |
def save_video_segments_blended(
    self,
    output_dir: str,
    img_ext: str = "png",
    dpi: int = 200,
    frame_stride: int = 1,
    output_video: Optional[str] = None,
    fps: int = 30,
) -> None:
    """Save each video frame with its segmentation masks blended on top.

    Reads the frames listed in ``self._frame_names`` from ``self.video_path``
    (both assigned by ``set_video``) and overlays the per-object masks found
    in ``self.video_segments`` for that frame index. Frames that have no
    entry in ``self.video_segments`` are written without any overlay.

    Note: every open matplotlib figure is closed before rendering starts.

    Args:
        output_dir (str): The directory to save the output images. It is
            created if it does not exist.
        img_ext (str): The file extension for the output images. Defaults to "png".
        dpi (int): The DPI (dots per inch) for the output images. Defaults to 200.
        frame_stride (int): The stride for selecting frames to save. Must be
            at least 1. Defaults to 1.
        output_video (Optional[str]): The path to the output video file. When
            given, the rendered images are assembled with
            ``common.images_to_video``. Defaults to None.
        fps (int): The frames per second for the output video. Defaults to 30.

    Raises:
        ValueError: If ``frame_stride`` is smaller than 1.
    """
    # The module-level name ``Image`` is the class imported from
    # ``PIL.Image``; the module itself is needed here for ``Image.open``.
    from PIL import Image

    if frame_stride < 1:
        raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")

    figure_scale = 1.3  # enlarge the canvas relative to the source image
    mask_alpha = 0.6  # opacity of the mask overlay
    cmap = plt.get_cmap("tab10")  # looked up once instead of once per mask

    def show_mask(mask, ax, obj_id=None, random_color=False):
        """Draw ``mask`` on ``ax`` as a semi-transparent colored layer."""
        if random_color:
            color = np.concatenate(
                [np.random.random(3), np.array([mask_alpha])], axis=0
            )
        else:
            # Wrap around the palette so ids beyond its size still get
            # distinct colors instead of all clipping to the last entry.
            cmap_idx = 0 if obj_id is None else int(obj_id) % cmap.N
            color = np.array([*cmap(cmap_idx)[:3], mask_alpha])
        h, w = mask.shape[-2:]
        # Broadcast the (h, w, 1) mask against the RGBA color -> (h, w, 4).
        ax.imshow(mask.reshape(h, w, 1) * color.reshape(1, 1, -1))

    os.makedirs(output_dir, exist_ok=True)

    plt.close("all")

    video_segments = self.video_segments
    video_dir = self.video_path
    frame_names = self._frame_names
    num_frames = len(frame_names)
    # Zero-pad file names so they sort in frame order.
    num_digits = len(str(num_frames))

    for out_frame_idx in tqdm(
        range(0, num_frames, frame_stride), desc="Rendering frames"
    ):
        frame_path = os.path.join(video_dir, frame_names[out_frame_idx])

        # The context manager releases the file handle after each frame.
        with Image.open(frame_path) as image:
            w, h = image.size

            # Size the figure from the image dimensions at the requested DPI.
            fig = plt.figure(
                figsize=((w / dpi) * figure_scale, (h / dpi) * figure_scale),
                dpi=dpi,
            )
            try:
                ax = fig.gca()
                # Disable axis to prevent whitespace
                ax.axis("off")
                ax.imshow(image)

                # Overlay masks for each object ID; a frame without any
                # segments is saved as the plain image.
                frame_masks = video_segments.get(out_frame_idx, {})
                for out_obj_id, out_mask in frame_masks.items():
                    show_mask(out_mask, ax, obj_id=out_obj_id)

                # Save the figure with no borders or extra padding
                filename = f"{str(out_frame_idx).zfill(num_digits)}.{img_ext}"
                fig.savefig(
                    os.path.join(output_dir, filename),
                    dpi=dpi,
                    pad_inches=0,
                    bbox_inches="tight",
                )
            finally:
                # Always release the figure, even if rendering fails.
                plt.close(fig)

    if output_video is not None:
        common.images_to_video(output_dir, output_video, fps=fps)
0 commit comments