Skip to content

Commit c38d149

Browse files
committed
Save prediction result
1 parent b37bfed commit c38d149

File tree

1 file changed

+37
-2
lines changed

1 file changed

+37
-2
lines changed

samgeo/samgeo2.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,7 @@ def set_video(
10901090
raise ValueError("Input video_path must be a string.")
10911091

10921092
self.video_path = output_dir
1093+
self._num_images = len(os.listdir(output_dir))
10931094
self.inference_state = self.predictor.init_state(video_path=output_dir)
10941095

10951096
def predict_video(
@@ -1105,6 +1106,27 @@ def predict_video(
11051106
output_dir (Optional[str]): The directory to save the output images. Defaults to None.
11061107
img_ext (str): The file extension for the output images. Defaults to "png".
11071108
"""
1109+
1110+
from PIL import Image
1111+
1112+
def save_image_from_dict(data, output_path="output_image.png"):
    """Merge per-object masks into one label image and write it to disk.

    Every pixel covered by an object's mask is set to that object's ID;
    pixels covered by no mask stay 0. When masks overlap, the object that
    comes later in ``data`` wins because it is painted last.

    Args:
        data (dict): Maps an integer object ID to a mask array. The mask is
            indexed as ``mask[0]`` and the image size is taken from
            ``mask.shape[1:]``, so each mask is expected to look like
            (1, H, W) and all masks must share the same shape.
        output_path (str): Destination file path. Defaults to
            "output_image.png".

    Raises:
        ValueError: If ``data`` is empty, or if an object ID is outside the
            range 0-65535 and therefore cannot be stored in the image.
    """
    if not data:
        # Without at least one mask the image size is unknown; the bare
        # next(iter(...)) used to leak a StopIteration here.
        raise ValueError("data must contain at least one mask.")

    object_ids = list(data)
    smallest_id, largest_id = min(object_ids), max(object_ids)
    if smallest_id < 0 or largest_id > np.iinfo(np.uint16).max:
        raise ValueError("Object IDs must be in the range 0-65535.")

    # uint8 cannot hold IDs above 255 (they would wrap or raise depending on
    # the NumPy version). Widen only when needed so that the output for the
    # common case (IDs <= 255) is unchanged.
    # NOTE(review): a 16-bit image needs an output format that supports it
    # (PNG should) - confirm for other img_ext values.
    if largest_id <= np.iinfo(np.uint8).max:
        label_dtype = np.uint8
    else:
        label_dtype = np.uint16

    # All masks are assumed to share the shape of the first one.
    first_mask = next(iter(data.values()))
    labels = np.zeros(first_mask.shape[1:], dtype=label_dtype)

    # Paint each object's ID wherever its mask selects pixels.
    for object_id, mask in data.items():
        labels[mask[0]] = object_id

    # ``Image`` is PIL's Image module, imported by the enclosing function.
    Image.fromarray(labels).save(output_path)
1129+
11081130
prompts = self._convert_prompts(prompts)
11091131
predictor = self.predictor
11101132
inference_state = self.inference_state
@@ -1121,6 +1143,13 @@ def predict_video(
11211143
)
11221144

11231145
video_segments = {}
1146+
num_frames = self._num_images
1147+
num_digits = len(str(num_frames))
1148+
1149+
if output_dir is not None:
1150+
if not os.path.exists(output_dir):
1151+
os.makedirs(output_dir)
1152+
11241153
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
11251154
inference_state
11261155
):
@@ -1129,10 +1158,16 @@ def predict_video(
11291158
for i, out_obj_id in enumerate(out_obj_ids)
11301159
}
11311160

1161+
if output_dir is not None:
1162+
output_path = os.path.join(
1163+
output_dir, f"{str(out_frame_idx).zfill(num_digits)}.{img_ext}"
1164+
)
1165+
save_image_from_dict(video_segments[out_frame_idx], output_path)
1166+
11321167
self.video_segments = video_segments
11331168

1134-
if output_dir is not None:
1135-
self.save_video_segments(output_dir, img_ext)
1169+
# if output_dir is not None:
1170+
# self.save_video_segments(output_dir, img_ext)
11361171

11371172
def save_video_segments(self, output_dir: str, img_ext: str = "png") -> None:
11381173
"""Save the video segments to the output directory.

0 commit comments

Comments
 (0)