Skip to content

Commit 87f82aa

Browse files
authored
Merge pull request #27 from FrontierDevelopmentLab/bugfix/boundless-window
fix COG reader
2 parents e242376 + ba71db5 commit 87f82aa

File tree

2 files changed

+108
-4
lines changed

2 files changed

+108
-4
lines changed

providers/gcp/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def extract_patches():
157157
cloud_fs=fs,
158158
download_f=download_blob,
159159
task=task,
160-
method="first",
160+
method="max",
161161
resolution=archive_resolution,
162162
)
163163

src/satextractor/extractor/extractor.py

Lines changed: 107 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,51 @@
66

77
import numpy as np
88
import rasterio
9+
from affine import Affine
910
from loguru import logger
1011
from osgeo import gdal
1112
from osgeo import osr
13+
from rasterio import warp
14+
from rasterio.crs import CRS
15+
from rasterio.enums import Resampling
1216
from rasterio.merge import merge as riomerge
1317
from satextractor.models import ExtractionTask
1418
from satextractor.models import Tile
1519

1620

21+
def get_window_union(
    tiles: List[Tile],
    ds: rasterio.io.DatasetReader,
) -> rasterio.windows.Window:
    """Compute the single raster window that covers every tile.

    Each tile's bounding box is reprojected from the tile's EPSG CRS into the
    CRS of ``ds`` and converted to a pixel window with the dataset transform;
    the union of those per-tile windows is returned.

    Args:
        tiles (List[Tile]): the tiles
        ds (rasterio.io.DatasetReader): the rasterio dataset to read
            (provides the CRS and the affine transform)

    Returns:
        rasterio.windows.Window: The union of all tile windows.
    """

    def _window_for(tile):
        # Tile bounds expressed in the raster's own CRS.
        raster_bounds = warp.transform_bounds(
            CRS.from_epsg(tile.epsg),
            ds.crs,
            *tile.bbox,
        )
        return rasterio.windows.from_bounds(*raster_bounds, ds.transform)

    tile_windows = [_window_for(tile) for tile in tiles]
    return rasterio.windows.union(tile_windows)
52+
53+
1754
def get_proj_win(tiles: List[Tile]) -> Tuple[int, int, int, int]:
1855
"""Get the projection bounds window of the tiles.
1956
@@ -48,6 +85,69 @@ def get_tile_pixel_coords(tiles: List[Tile], raster_file: str) -> List[Tuple[int
4885
return list(zip(rows, cols))
4986

5087

88+
def download_and_extract_tiles_window_COG(
    fs: Any,
    task: ExtractionTask,
    resolution: int,
) -> List[str]:
    """Read the task window from each COG asset and write it to a local GeoTIFF.

    For every item in the task's item collection, the asset of the task band
    is opened through ``fs``, the union window of the task tiles is read
    (resampled to the output grid) and the result is written to
    ``{task_id}_{index}.tif`` as a relative path.

    Args:
        fs (Any): Filesystem-like object; ``fs.open(url)`` must return a file
            object that ``rasterio.open`` accepts.
        task (ExtractionTask): The extraction task
        resolution (int): The target resolution

    Returns:
        List[str]: A list of files that store the crops of the original assets
    """

    # task tiles all have same CRS, so get their max extents and crs
    left, top, right, bottom = get_proj_win(task.tiles)
    epsg = task.tiles[0].epsg

    # set the transform and pixel dimensions of the output file
    dst_transform = Affine(resolution, 0.0, left, 0.0, -resolution, top)
    width = int((right - left) / resolution)
    height = int((top - bottom) / resolution)

    outfiles = []

    band = task.band
    urls = [item.assets[band].href for item in task.item_collection.items]

    for ii, url in enumerate(urls):
        with fs.open(url) as f:
            with rasterio.open(f) as ds:
                window = get_window_union(task.tiles, ds)

                # rasterio arrays are (rows, cols): out_shape must be
                # (height, width). Passing (width, height) transposes the
                # read for any non-square extent and mismatches the output
                # dataset dimensions declared below.
                rst_arr = ds.read(
                    1,
                    window=window,
                    out_shape=(height, width),
                    fill_value=0,
                    boundless=True,
                    resampling=Resampling.bilinear,
                )

        out_f = f"{task.task_id}_{ii}.tif"

        with rasterio.open(
            out_f,
            "w",
            driver="GTiff",
            count=1,
            width=width,
            height=height,
            transform=dst_transform,
            crs=CRS.from_epsg(epsg),
            dtype=rst_arr.dtype,
        ) as dst:
            dst.write(rst_arr, indexes=1)

        outfiles.append(out_f)

    return outfiles
149+
150+
51151
def download_and_extract_tiles_window(
52152
download_f: Callable,
53153
task: ExtractionTask,
@@ -112,7 +212,7 @@ def task_mosaic_patches(
112212
cloud_fs: Any,
113213
download_f: Callable,
114214
task: ExtractionTask,
115-
method: str = "first",
215+
method: str = "max",
116216
resolution: int = 10,
117217
dst_path="merged.jp2",
118218
) -> List[np.ndarray]:
@@ -121,14 +221,18 @@ def task_mosaic_patches(
121221
Args:
122222
download_f (Callable): The function to download the task assets
123223
task (ExtractionTask): The task
124-
method (str, optional): The method to use while merging the assets. Defaults to "first".
224+
method (str, optional): The method to use while merging the assets. Defaults to "max".
125225
resolution (int, optional): The target resolution. Defaults to 10.
126226
dst_path (str): path to store the merged files
127227
128228
Returns:
129229
List[np.ndarray]: The tile patches as numpy arrays
130230
"""
131-
out_files = download_and_extract_tiles_window(download_f, task, resolution)
231+
232+
if task.constellation == "sentinel-2":
233+
out_files = download_and_extract_tiles_window(download_f, task, resolution)
234+
else:
235+
out_files = download_and_extract_tiles_window_COG(cloud_fs, task, resolution)
132236

133237
out_f = f"{task.task_id}_{dst_path}"
134238
datasets = [rasterio.open(f) for f in out_files]

0 commit comments

Comments
 (0)