@@ -1090,6 +1090,7 @@ def set_video(
1090
1090
raise ValueError ("Input video_path must be a string." )
1091
1091
1092
1092
self .video_path = output_dir
1093
+ self ._num_images = len (os .listdir (output_dir ))
1093
1094
self .inference_state = self .predictor .init_state (video_path = output_dir )
1094
1095
1095
1096
def predict_video (
@@ -1105,6 +1106,27 @@ def predict_video(
1105
1106
output_dir (Optional[str]): The directory to save the output images. Defaults to None.
1106
1107
img_ext (str): The file extension for the output images. Defaults to "png".
1107
1108
"""
1109
+
1110
+ from PIL import Image
1111
+
1112
def save_image_from_dict(data, output_path="output_image.png"):
    """Merge per-object masks into one label image and save it to disk.

    Each pixel of the saved image holds the key (object id) of the mask
    that covers it; pixels covered by no mask stay 0. When masks overlap,
    the mask that comes later in ``data``'s iteration order wins.

    Args:
        data (dict): Maps an integer object id to a mask array. The first
            axis of each array is dropped (``array[0]``) and the rest is
            used to index the output, so all masks are assumed to share
            the same shape -- presumably boolean arrays of shape
            (1, H, W); verify against the caller.
        output_path (str): Destination file path. Defaults to
            "output_image.png".

    Raises:
        ValueError: If ``data`` is empty (the image shape cannot be
            inferred) or an object id falls outside 0..65535.
    """
    if not data:
        # Previously this surfaced as a bare StopIteration from next(iter(...)).
        raise ValueError("Cannot save a label image from an empty mask dictionary.")

    # The image shape is taken from the first mask, minus its leading axis.
    array_shape = next(iter(data.values())).shape[1:]

    # Pick the narrowest unsigned dtype that can hold every object id.
    # A fixed uint8 would wrap (NumPy 1.x) or raise (NumPy 2.x) for ids
    # above 255. NOTE(review): a uint16 image may not be writable in every
    # format (e.g. JPEG); confirm the extensions callers actually use.
    min_label, max_label = min(data), max(data)
    if min_label < 0 or max_label > np.iinfo(np.uint16).max:
        raise ValueError(
            f"Object ids must be within 0..{np.iinfo(np.uint16).max}, "
            f"got range {min_label}..{max_label}."
        )
    label_dtype = np.uint8 if max_label <= np.iinfo(np.uint8).max else np.uint16

    output_array = np.zeros(array_shape, dtype=label_dtype)

    for key, array in data.items():
        # Write the object id wherever this object's mask selects pixels.
        output_array[array[0]] = key

    # `Image` (PIL) is imported in the enclosing scope.
    Image.fromarray(output_array).save(output_path)
1129
+
1108
1130
prompts = self ._convert_prompts (prompts )
1109
1131
predictor = self .predictor
1110
1132
inference_state = self .inference_state
@@ -1121,6 +1143,13 @@ def predict_video(
1121
1143
)
1122
1144
1123
1145
video_segments = {}
1146
+ num_frames = self ._num_images
1147
+ num_digits = len (str (num_frames ))
1148
+
1149
+ if output_dir is not None :
1150
+ if not os .path .exists (output_dir ):
1151
+ os .makedirs (output_dir )
1152
+
1124
1153
for out_frame_idx , out_obj_ids , out_mask_logits in predictor .propagate_in_video (
1125
1154
inference_state
1126
1155
):
@@ -1129,10 +1158,16 @@ def predict_video(
1129
1158
for i , out_obj_id in enumerate (out_obj_ids )
1130
1159
}
1131
1160
1161
+ if output_dir is not None :
1162
+ output_path = os .path .join (
1163
+ output_dir , f"{ str (out_frame_idx ).zfill (num_digits )} .{ img_ext } "
1164
+ )
1165
+ save_image_from_dict (video_segments [out_frame_idx ], output_path )
1166
+
1132
1167
self .video_segments = video_segments
1133
1168
1134
- if output_dir is not None :
1135
- self .save_video_segments (output_dir , img_ext )
1169
+ # if output_dir is not None:
1170
+ # self.save_video_segments(output_dir, img_ext)
1136
1171
1137
1172
def save_video_segments (self , output_dir : str , img_ext : str = "png" ) -> None :
1138
1173
"""Save the video segments to the output directory.
0 commit comments