@@ -1071,15 +1071,14 @@ def set_video(
1071
1071
step (int, optional): The step size. Defaults to 1.
1072
1072
frame_rate (Optional[int], optional): The frame rate. Defaults to None.
1073
1073
"""
1074
- import tempfile
1075
1074
1076
1075
if isinstance (video_path , str ):
1077
1076
if video_path .startswith ("http" ):
1078
1077
video_path = common .download_file (video_path )
1079
1078
if os .path .isfile (video_path ):
1080
1079
1081
1080
if output_dir is None :
1082
- output_dir = tempfile . mkdtemp ()
1081
+ output_dir = common . make_temp_dir ()
1083
1082
if not os .path .exists (output_dir ):
1084
1083
os .makedirs (output_dir )
1085
1084
print (f"Output directory: { output_dir } " )
@@ -1088,6 +1087,14 @@ def set_video(
1088
1087
)
1089
1088
1090
1089
elif os .path .isdir (video_path ):
1090
+ files = sorted (os .listdir (video_path ))
1091
+ if len (files ) == 0 :
1092
+ raise ValueError (f"No files found in { video_path } ." )
1093
+ elif files [0 ].endswith (".tif" ):
1094
+ self ._tif_source = os .path .join (video_path , files [0 ])
1095
+ self ._tif_dir = video_path
1096
+ self ._tif_names = files
1097
+ video_path = common .geotiff_to_jpg_batch (video_path )
1091
1098
output_dir = video_path
1092
1099
1093
1100
if not os .path .exists (video_path ):
@@ -1189,7 +1196,9 @@ def save_video_segments(self, output_dir: str, img_ext: str = "png") -> None:
1189
1196
"""
1190
1197
from PIL import Image
1191
1198
1192
- def save_image_from_dict (data , output_path = "output_image.png" ):
1199
+ def save_image_from_dict (
1200
+ data , output_path = "output_image.png" , crs_source = None , ** kwargs
1201
+ ):
1193
1202
# Find the shape of the first array in the dictionary (assuming all arrays have the same shape)
1194
1203
array_shape = next (iter (data .values ())).shape [1 :]
1195
1204
@@ -1201,26 +1210,40 @@ def save_image_from_dict(data, output_path="output_image.png"):
1201
1210
# Assign the key value wherever the boolean array is True
1202
1211
output_array [array [0 ]] = key
1203
1212
1204
- # Convert the output array to a PIL image
1205
- image = Image .fromarray (output_array )
1213
+ if crs_source is None :
1214
+ # Convert the output array to a PIL image
1215
+ image = Image .fromarray (output_array )
1206
1216
1207
- # Save the image
1208
- image .save (output_path )
1209
-
1210
- if not os . path . exists ( output_dir ):
1211
- os . makedirs ( output_dir )
1217
+ # Save the image
1218
+ image .save (output_path )
1219
+ else :
1220
+ output_path = output_path . replace ( ".png" , ".tif" )
1221
+ common . array_to_image ( output_array , output_path , crs_source , ** kwargs )
1212
1222
1213
1223
num_frames = len (self .video_segments )
1214
1224
num_digits = len (str (num_frames ))
1215
1225
1226
+ if hasattr (self , "_tif_source" ) and self ._tif_source .endswith (".tif" ):
1227
+ crs_source = self ._tif_source
1228
+ filenames = self ._tif_names
1229
+ else :
1230
+ crs_source = None
1231
+ filenames = None
1232
+
1233
+ if not os .path .exists (output_dir ):
1234
+ os .makedirs (output_dir )
1235
+
1216
1236
# Initialize the tqdm progress bar
1217
1237
for frame_idx , video_segment in tqdm (
1218
1238
self .video_segments .items (), desc = "Rendering frames" , total = num_frames
1219
1239
):
1220
- output_path = os .path .join (
1221
- output_dir , f"{ str (frame_idx ).zfill (num_digits )} .{ img_ext } "
1222
- )
1223
- save_image_from_dict (video_segment , output_path )
1240
+ if filenames is None :
1241
+ output_path = os .path .join (
1242
+ output_dir , f"{ str (frame_idx ).zfill (num_digits )} .{ img_ext } "
1243
+ )
1244
+ else :
1245
+ output_path = os .path .join (output_dir , filenames [frame_idx ])
1246
+ save_image_from_dict (video_segment , output_path , crs_source )
1224
1247
1225
1248
def save_video_segments_blended (
1226
1249
self ,
@@ -1390,7 +1413,10 @@ def show_box(box, ax):
1390
1413
prompts = self ._convert_prompts (prompts )
1391
1414
video_dir = self .video_path
1392
1415
frame_names = self ._frame_names
1393
- plt .figure (figsize = figsize )
1416
+ fig = plt .figure (figsize = figsize )
1417
+ fig .canvas .toolbar_visible = True
1418
+ fig .canvas .header_visible = False
1419
+ fig .canvas .footer_visible = True
1394
1420
plt .title (f"frame { frame_idx } " )
1395
1421
plt .imshow (Image .open (os .path .join (video_dir , frame_names [frame_idx ])))
1396
1422
@@ -1406,3 +1432,5 @@ def show_box(box, ax):
1406
1432
show_box (box , plt .gca ())
1407
1433
if mask is not None :
1408
1434
show_mask (mask , plt .gca (), obj_id = obj_id )
1435
+
1436
+ plt .show ()
0 commit comments