Skip to content

Commit 1a35669

Browse files
committed
Add support for GeoTIFF
1 parent 7e5c754 commit 1a35669

File tree

2 files changed

+123
-15
lines changed

2 files changed

+123
-15
lines changed

samgeo/common.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3433,3 +3433,83 @@ def on_slider_change(change):
34333433

34343434
# Show the plot
34353435
plt.show()
3436+
3437+
3438+
def make_temp_dir(**kwargs) -> str:
3439+
"""Create a temporary directory and return the path.
3440+
3441+
Returns:
3442+
str: The path to the temporary directory.
3443+
"""
3444+
import tempfile
3445+
3446+
temp_dir = tempfile.mkdtemp(**kwargs)
3447+
return temp_dir
3448+
3449+
3450+
def geotiff_to_jpg(geotiff_path: str, output_path: str) -> None:
3451+
"""Convert a GeoTIFF file to a JPG file.
3452+
3453+
Args:
3454+
geotiff_path (str): The path to the input GeoTIFF file.
3455+
output_path (str): The path to the output JPG file.
3456+
"""
3457+
3458+
from PIL import Image
3459+
3460+
# Open the GeoTIFF file
3461+
with rasterio.open(geotiff_path) as src:
3462+
# Read the first band (for grayscale) or all bands
3463+
array = src.read()
3464+
3465+
# If the array has more than 3 bands, reduce it to the first 3 (RGB)
3466+
if array.shape[0] >= 3:
3467+
array = array[:3, :, :] # Select the first 3 bands (R, G, B)
3468+
elif array.shape[0] == 1:
3469+
# For single-band images, repeat the band to create a grayscale RGB
3470+
array = np.repeat(array, 3, axis=0)
3471+
3472+
# Transpose the array from (bands, height, width) to (height, width, bands)
3473+
array = np.transpose(array, (1, 2, 0))
3474+
3475+
# Normalize the array to 8-bit (0-255) range for JPG
3476+
array = array.astype(np.float32)
3477+
array -= array.min()
3478+
array /= array.max()
3479+
array *= 255
3480+
array = array.astype(np.uint8)
3481+
3482+
# Convert to a PIL Image and save as JPG
3483+
image = Image.fromarray(array)
3484+
image.save(output_path)
3485+
3486+
3487+
def geotiff_to_jpg_batch(input_folder: str, output_folder: str = None) -> str:
3488+
"""Convert all GeoTIFF files in a folder to JPG files.
3489+
3490+
Args:
3491+
input_folder (str): The path to the folder containing GeoTIFF files.
3492+
output_folder (str): The path to the folder to save the output JPG files.
3493+
3494+
Returns:
3495+
str: The path to the output folder containing the JPG files.
3496+
"""
3497+
3498+
if output_folder is None:
3499+
output_folder = make_temp_dir()
3500+
3501+
if not os.path.exists(output_folder):
3502+
os.makedirs(output_folder)
3503+
3504+
geotiff_files = [
3505+
f for f in os.listdir(input_folder) if f.endswith(".tif") or f.endswith(".tiff")
3506+
]
3507+
3508+
# Initialize tqdm progress bar
3509+
for filename in tqdm(geotiff_files, desc="Converting GeoTIFF to JPG"):
3510+
geotiff_path = os.path.join(input_folder, filename)
3511+
jpg_filename = os.path.splitext(filename)[0] + ".jpg"
3512+
output_path = os.path.join(output_folder, jpg_filename)
3513+
geotiff_to_jpg(geotiff_path, output_path)
3514+
3515+
return output_folder

samgeo/samgeo2.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,15 +1071,14 @@ def set_video(
10711071
step (int, optional): The step size. Defaults to 1.
10721072
frame_rate (Optional[int], optional): The frame rate. Defaults to None.
10731073
"""
1074-
import tempfile
10751074

10761075
if isinstance(video_path, str):
10771076
if video_path.startswith("http"):
10781077
video_path = common.download_file(video_path)
10791078
if os.path.isfile(video_path):
10801079

10811080
if output_dir is None:
1082-
output_dir = tempfile.mkdtemp()
1081+
output_dir = common.make_temp_dir()
10831082
if not os.path.exists(output_dir):
10841083
os.makedirs(output_dir)
10851084
print(f"Output directory: {output_dir}")
@@ -1088,6 +1087,14 @@ def set_video(
10881087
)
10891088

10901089
elif os.path.isdir(video_path):
1090+
files = sorted(os.listdir(video_path))
1091+
if len(files) == 0:
1092+
raise ValueError(f"No files found in {video_path}.")
1093+
elif files[0].endswith(".tif"):
1094+
self._tif_source = os.path.join(video_path, files[0])
1095+
self._tif_dir = video_path
1096+
self._tif_names = files
1097+
video_path = common.geotiff_to_jpg_batch(video_path)
10911098
output_dir = video_path
10921099

10931100
if not os.path.exists(video_path):
@@ -1189,7 +1196,9 @@ def save_video_segments(self, output_dir: str, img_ext: str = "png") -> None:
11891196
"""
11901197
from PIL import Image
11911198

1192-
def save_image_from_dict(data, output_path="output_image.png"):
1199+
def save_image_from_dict(
1200+
data, output_path="output_image.png", crs_source=None, **kwargs
1201+
):
11931202
# Find the shape of the first array in the dictionary (assuming all arrays have the same shape)
11941203
array_shape = next(iter(data.values())).shape[1:]
11951204

@@ -1201,26 +1210,40 @@ def save_image_from_dict(data, output_path="output_image.png"):
12011210
# Assign the key value wherever the boolean array is True
12021211
output_array[array[0]] = key
12031212

1204-
# Convert the output array to a PIL image
1205-
image = Image.fromarray(output_array)
1213+
if crs_source is None:
1214+
# Convert the output array to a PIL image
1215+
image = Image.fromarray(output_array)
12061216

1207-
# Save the image
1208-
image.save(output_path)
1209-
1210-
if not os.path.exists(output_dir):
1211-
os.makedirs(output_dir)
1217+
# Save the image
1218+
image.save(output_path)
1219+
else:
1220+
output_path = output_path.replace(".png", ".tif")
1221+
common.array_to_image(output_array, output_path, crs_source, **kwargs)
12121222

12131223
num_frames = len(self.video_segments)
12141224
num_digits = len(str(num_frames))
12151225

1226+
if hasattr(self, "_tif_source") and self._tif_source.endswith(".tif"):
1227+
crs_source = self._tif_source
1228+
filenames = self._tif_names
1229+
else:
1230+
crs_source = None
1231+
filenames = None
1232+
1233+
if not os.path.exists(output_dir):
1234+
os.makedirs(output_dir)
1235+
12161236
# Initialize the tqdm progress bar
12171237
for frame_idx, video_segment in tqdm(
12181238
self.video_segments.items(), desc="Rendering frames", total=num_frames
12191239
):
1220-
output_path = os.path.join(
1221-
output_dir, f"{str(frame_idx).zfill(num_digits)}.{img_ext}"
1222-
)
1223-
save_image_from_dict(video_segment, output_path)
1240+
if filenames is None:
1241+
output_path = os.path.join(
1242+
output_dir, f"{str(frame_idx).zfill(num_digits)}.{img_ext}"
1243+
)
1244+
else:
1245+
output_path = os.path.join(output_dir, filenames[frame_idx])
1246+
save_image_from_dict(video_segment, output_path, crs_source)
12241247

12251248
def save_video_segments_blended(
12261249
self,
@@ -1390,7 +1413,10 @@ def show_box(box, ax):
13901413
prompts = self._convert_prompts(prompts)
13911414
video_dir = self.video_path
13921415
frame_names = self._frame_names
1393-
plt.figure(figsize=figsize)
1416+
fig = plt.figure(figsize=figsize)
1417+
fig.canvas.toolbar_visible = True
1418+
fig.canvas.header_visible = False
1419+
fig.canvas.footer_visible = True
13941420
plt.title(f"frame {frame_idx}")
13951421
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
13961422

@@ -1406,3 +1432,5 @@ def show_box(box, ax):
14061432
show_box(box, plt.gca())
14071433
if mask is not None:
14081434
show_mask(mask, plt.gca(), obj_id=obj_id)
1435+
1436+
plt.show()

0 commit comments

Comments
 (0)