Skip to content

Commit ab502bb

Browse files
committed
Refactor stac.py to extract raster operations to raster.py
1 parent 30e6ff0 commit ab502bb

File tree

6 files changed

+3348
-1259
lines changed

6 files changed

+3348
-1259
lines changed

geospatial_tools/raster.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,5 +235,67 @@ def get_total_band_count(raster_file_list: list[Union[pathlib.Path, str]], logge
235235
for raster in raster_file_list:
236236
with rasterio.open(raster, "r") as raster_image:
237237
total_band_count += raster_image.count
238-
logger.info(f"Calculated a total of [{total_band_count}] bands")
238+
logger.info(f"Calculated a total of [{total_band_count}] bands")
239239
return total_band_count
240+
241+
242+
def create_merged_raster_bands_metadata(
243+
raster_file_list: list[Union[pathlib.Path, str]], logger: logging.Logger = LOGGER
244+
) -> dict:
245+
"""
246+
247+
Parameters
248+
----------
249+
raster_file_list
250+
logger
251+
252+
Returns
253+
-------
254+
255+
"""
256+
logger.info("Creating merged asset metadata")
257+
total_band_count = get_total_band_count(raster_file_list)
258+
with rasterio.open(raster_file_list[0]) as meta_source:
259+
meta = meta_source.meta
260+
meta.update(count=total_band_count)
261+
return meta
262+
263+
264+
def merge_raster_bands(
265+
merged_filename: Union[pathlib.Path, str],
266+
raster_file_list: list[Union[pathlib.Path, str]],
267+
metadata: dict = None,
268+
band_names: list[str] = None,
269+
logger: logging.Logger = LOGGER,
270+
) -> Optional[pathlib.Path]:
271+
if not metadata:
272+
metadata = create_merged_raster_bands_metadata(raster_file_list)
273+
merged_image_index = 1
274+
band_index = 0
275+
logger.info(f"Merging asset [{merged_filename}] ...")
276+
with rasterio.open(merged_filename, "w", **metadata) as merged_asset_image:
277+
for asset_sub_item in raster_file_list:
278+
asset_name = pathlib.Path(asset_sub_item).name
279+
logger.info(f"Writing band image: {asset_name}")
280+
with rasterio.open(asset_sub_item) as asset_band_image:
281+
num_of_bands = asset_band_image.count
282+
for asset_band_image_index in range(1, num_of_bands + 1):
283+
logger.info(
284+
f"Writing asset sub item band {asset_band_image_index} to merged index band {merged_image_index}"
285+
)
286+
merged_asset_image.write_band(merged_image_index, asset_band_image.read(asset_band_image_index))
287+
asset_description_index = asset_band_image_index - 1
288+
description = asset_band_image.descriptions[asset_description_index]
289+
if band_names:
290+
description = band_names[band_index]
291+
if num_of_bands > 1:
292+
description = f"{description}-{asset_band_image_index}"
293+
merged_asset_image.set_band_description(merged_image_index, description)
294+
merged_asset_image.update_tags(merged_image_index, **asset_band_image.tags(asset_band_image_index))
295+
merged_image_index += 1
296+
band_index += 1
297+
298+
if not merged_filename.exists():
299+
return None
300+
301+
return merged_filename

geospatial_tools/stac.py

Lines changed: 17 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@
77

88
import pystac
99
import pystac_client
10-
import rasterio
1110
from planetary_computer import sign_inplace
1211
from pystac_client.exceptions import APIError
1312

1413
from geospatial_tools import geotools_types
15-
from geospatial_tools.raster import get_total_band_count, reproject_raster
14+
from geospatial_tools.raster import (
15+
create_merged_raster_bands_metadata,
16+
get_total_band_count,
17+
merge_raster_bands,
18+
reproject_raster,
19+
)
1620
from geospatial_tools.utils import create_logger, download_url
1721

1822
LOGGER = create_logger(__name__)
@@ -112,33 +116,14 @@ def merge_asset(self, base_directory: Optional[Union[str, pathlib.Path]] = None,
112116

113117
merged_filename = base_directory / f"{self.asset_id}_merged.tif"
114118

115-
total_band_count = self._get_asset_total_bands()
116-
117-
self.logger.info(total_band_count)
118-
119-
meta = self._create_merged_asset_metadata(total_band_count)
120-
121-
merged_image_index = 1
122-
band_index = 0
123-
self.logger.info(f"Merging asset [{self.asset_id}] ...")
124-
with rasterio.open(merged_filename, "w", **meta) as merged_asset_image:
125-
for asset_sub_item in self.list:
126-
self.logger.info(f"Writing band image: {asset_sub_item.item_id}")
127-
with rasterio.open(asset_sub_item.filename) as asset_band_image:
128-
num_of_bands = asset_band_image.count
129-
for asset_band_image_index in range(1, num_of_bands + 1):
130-
self.logger.info(f"writing asset sub item band {asset_band_image_index}")
131-
self.logger.info(f"writing merged index band {merged_image_index}")
132-
merged_asset_image.write_band(merged_image_index, asset_band_image.read(asset_band_image_index))
133-
description = self.bands[band_index]
134-
if num_of_bands > 1:
135-
description = f"{description}-{asset_band_image_index}"
136-
merged_asset_image.set_band_description(merged_image_index, description)
137-
merged_asset_image.update_tags(
138-
merged_image_index, **asset_band_image.tags(asset_band_image_index)
139-
)
140-
merged_image_index += 1
141-
band_index += 1
119+
asset_filename_list = [asset.filename for asset in self.list]
120+
121+
meta = self._create_merged_asset_metadata()
122+
123+
merge_raster_bands(
124+
merged_filename=merged_filename, raster_file_list=asset_filename_list, metadata=meta, band_names=self.bands
125+
)
126+
142127
if merged_filename.exists():
143128
self.logger.info(f"Asset [{self.asset_id}] merged successfully")
144129
self.logger.info(f"Asset location : [{merged_filename}]")
@@ -189,11 +174,10 @@ def delete_reprojected_asset(self):
189174
self.logger.info(f"Deleting reprojected asset file for [{self.reprojected_asset_path}]")
190175
self.reprojected_asset_path.unlink()
191176

192-
def _create_merged_asset_metadata(self, total_band_count):
177+
def _create_merged_asset_metadata(self):
193178
self.logger.info("Creating merged asset metadata")
194-
with rasterio.open(self.list[0].filename) as meta_source:
195-
meta = meta_source.meta
196-
meta.update(count=total_band_count)
179+
file_list = [asset.filename for asset in self.list]
180+
meta = create_merged_raster_bands_metadata(file_list)
197181
return meta
198182

199183
def _get_asset_total_bands(self):

0 commit comments

Comments
 (0)