Skip to content

Commit 7e5c754

Browse files
committed
Add show_prompts and show_images functions
1 parent e4b815f commit 7e5c754

File tree

1 file changed

+102
-2
lines changed

1 file changed

+102
-2
lines changed

samgeo/samgeo2.py

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1226,6 +1226,7 @@ def save_video_segments_blended(
12261226
self,
12271227
output_dir: str,
12281228
img_ext: str = "png",
1229+
alpha: float = 0.6,
12291230
dpi: int = 200,
12301231
frame_stride: int = 1,
12311232
output_video: Optional[str] = None,
@@ -1236,6 +1237,8 @@ def save_video_segments_blended(
12361237
Args:
12371238
output_dir (str): The directory to save the output images.
12381239
img_ext (str): The file extension for the output images. Defaults to "png".
1240+
alpha (float): The alpha value for the blended masks. Defaults to 0.6.
1241+
12391242
dpi (int): The DPI (dots per inch) for the output images. Defaults to 200.
12401243
frame_stride (int): The stride for selecting frames to save. Defaults to 1.
12411244
output_video (Optional[str]): The path to the output video file. Defaults to None.
@@ -1246,11 +1249,11 @@ def save_video_segments_blended(
12461249

12471250
def show_mask(mask, ax, obj_id=None, random_color=False):
12481251
if random_color:
1249-
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
1252+
color = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)
12501253
else:
12511254
cmap = plt.get_cmap("tab10")
12521255
cmap_idx = 0 if obj_id is None else obj_id
1253-
color = np.array([*cmap(cmap_idx)[:3], 0.6])
1256+
color = np.array([*cmap(cmap_idx)[:3], alpha])
12541257
h, w = mask.shape[-2:]
12551258
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
12561259
ax.imshow(mask_image)
@@ -1306,3 +1309,100 @@ def show_mask(mask, ax, obj_id=None, random_color=False):
13061309

13071310
if output_video is not None:
13081311
common.images_to_video(output_dir, output_video, fps=fps)
1312+
1313+
def show_images(self, path: Optional[str] = None) -> None:
    """Show the images in the video.

    Args:
        path (Optional[str]): The path to the images, passed straight to
            ``common.show_image_gui``. When None, ``self.video_path`` is
            used instead. Defaults to None.

    Returns:
        None. If both ``path`` and ``self.video_path`` are None, nothing
        is displayed and the call returns silently.
    """
    # Fall back to the video path stored on the instance.
    if path is None:
        path = self.video_path

    # Only open the viewer when there is actually something to show.
    if path is not None:
        common.show_image_gui(path)
1324+
1325+
def show_prompts(
    self,
    prompts: Dict[int, Any],
    frame_idx: int = 0,
    mask: Any = None,
    random_color: bool = False,
    figsize: Tuple[int, int] = (9, 6),
    alpha: float = 0.6,
) -> None:
    """Show the prompts on the image.

    The frame ``self._frame_names[frame_idx]`` is loaded from
    ``self.video_path`` and drawn on a new matplotlib figure. Every prompt
    whose ``"frame_idx"`` equals ``frame_idx`` is then overlaid on it.

    Args:
        prompts (Dict[int, Any]): A dictionary mapping object ids to
            prompts. It is first normalized with ``self._convert_prompts``;
            afterwards each prompt is read for the keys ``"points"``,
            ``"labels"``, ``"box"`` and ``"frame_idx"`` (all optional).
        frame_idx (int, optional): The frame index. Defaults to 0.
        mask (Any, optional): The mask to overlay. It is reshaped to
            ``(h, w, 1)`` using its last two dimensions, and the same mask
            is drawn once for each prompt that matches ``frame_idx``.
            Defaults to None.
        random_color (bool, optional): Whether to use random colors for the
            masks instead of the ``tab10`` colormap indexed by object id.
            Defaults to False.
        figsize (Tuple[int, int], optional): The figure size.
            Defaults to (9, 6).
        alpha (float, optional): The alpha value for the mask overlay.
            Defaults to 0.6.
    """

    from PIL import Image

    def show_mask(mask, ax, obj_id=None, random_color=random_color):
        # Build an RGBA color whose last channel is the requested opacity.
        if random_color:
            color = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)
        else:
            cmap = plt.get_cmap("tab10")
            cmap_idx = 0 if obj_id is None else obj_id
            color = np.array([*cmap(cmap_idx)[:3], alpha])
        h, w = mask.shape[-2:]
        # Broadcasting (h, w, 1) against (1, 1, 4) yields an (h, w, 4) image.
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.imshow(mask_image)

    def show_points(coords, labels, ax, marker_size=200):
        # Label 1 marks a positive (green) point, label 0 a negative (red) one.
        pos_points = coords[labels == 1]
        neg_points = coords[labels == 0]
        ax.scatter(
            pos_points[:, 0],
            pos_points[:, 1],
            color="green",
            marker="*",
            s=marker_size,
            edgecolor="white",
            linewidth=1.25,
        )
        ax.scatter(
            neg_points[:, 0],
            neg_points[:, 1],
            color="red",
            marker="*",
            s=marker_size,
            edgecolor="white",
            linewidth=1.25,
        )

    def show_box(box, ax):
        # The box is read as (x0, y0, x1, y1); Rectangle wants origin + size.
        x0, y0 = box[0], box[1]
        w, h = box[2] - box[0], box[3] - box[1]
        ax.add_patch(
            plt.Rectangle(
                (x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2
            )
        )

    prompts = self._convert_prompts(prompts)
    video_dir = self.video_path
    frame_names = self._frame_names
    plt.figure(figsize=figsize)
    plt.title(f"frame {frame_idx}")
    plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))

    # All overlays go on the axes that now hold the frame image.
    ax = plt.gca()

    for obj_id, prompt in prompts.items():
        points = prompt.get("points", None)
        labels = prompt.get("labels", None)
        box = prompt.get("box", None)
        anno_frame_idx = prompt.get("frame_idx", None)
        # Only draw prompts annotated on the frame being displayed.
        if anno_frame_idx == frame_idx:
            if points is not None:
                show_points(points, labels, ax)
            if box is not None:
                show_box(box, ax)
            if mask is not None:
                show_mask(mask, ax, obj_id=obj_id)

0 commit comments

Comments
 (0)