Skip to content

Commit 39e2607

Browse files
committed
Extract tiles from an image at random
Has functionality to constrain the tiles based on their RMS or latitude
1 parent 0b1fad1 commit 39e2607

File tree

5 files changed

+1746
-1284
lines changed

5 files changed

+1746
-1284
lines changed

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ authors = [
88
]
99

1010
requires-python = ">=3.10"
11+
dependencies = [
12+
"nbdime>=4.0.2",
13+
]
1114

1215
[tool.uv]
1316
default-groups = ["core", "test"]
@@ -43,4 +46,4 @@ dev = [
4346

4447
[build-system]
4548
requires = ["hatchling"]
46-
build-backend = "hatchling.build"
49+
build-backend = "hatchling.build"

src/current_denoising/generation/ioutils.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,155 @@ def read_currents(path: pathlib.Path) -> np.ndarray:
3535
raise IOError("Close marker does not match the opener.")
3636

3737
return data.reshape(shape)
38+
39+
40+
def _included_indices(
    n_rows: int, tile_size: int, max_latitude: float
) -> tuple[int, int]:
    """
    Find the half-open range of row indices at which the top of a tile may be placed,
    given the size of the input image and the maximum latitude.

    Assumes the input image is centred on the equator and ranges from -90 to 90 degrees latitude,
    with row 0 at +90 degrees and the last row at -90 degrees.

    :param n_rows: number of rows in the input image
    :param tile_size: side length of a tile, in rows
    :param max_latitude: tiles may not contain rows further than this many degrees N/S of the equator

    :returns: ``(min_row, max_row)``. Every start row ``y`` with ``min_row <= y < max_row`` gives a
              tile whose rows all lie within ``[-max_latitude, max_latitude]``. The upper bound is
              exclusive so the pair can be passed straight to ``rng.integers(low, high)``.
    :raises IOError: if ``max_latitude`` is not positive
    :raises IOError: if fewer than ``tile_size`` rows lie strictly inside the latitude band
    :raises RuntimeError: if no row lies within the latitude band
    """
    if max_latitude <= 0:
        raise IOError("Maximum latitude must be > 0")

    # Given the number of rows in the image, find which latitude each row corresponds to
    latitudes = np.linspace(90, -90, n_rows, endpoint=True)

    # Check if we have enough allowed latitudes.
    # NOTE: this count uses strict inequalities, so it is slightly more conservative than the
    # inclusive band used for the returned range below; kept as-is so the error conditions
    # are unchanged.
    allowed_latitudes = np.sum((latitudes < max_latitude) & (latitudes > -max_latitude))
    if allowed_latitudes < tile_size:
        raise IOError(
            f"Not enough allowed latitudes ({allowed_latitudes}) to fit a tile of size {tile_size}"
        )

    # Rows whose latitude lies within the band (thresholds inclusive). The latitudes are
    # monotonic, so these rows form one contiguous run.
    in_band = np.flatnonzero((latitudes <= max_latitude) & (latitudes >= -max_latitude))
    if in_band.size == 0:
        raise RuntimeError("No rows found below the provided maximum latitude")

    first_row = int(in_band[0])
    last_row = int(in_band[-1])

    # A tile starting at row y covers rows y .. y + tile_size - 1, so the last valid start is
    # last_row - tile_size + 1; add one more to make the upper bound exclusive.
    min_row = first_row
    max_row = last_row - tile_size + 2

    return min_row, max_row
88+
89+
90+
def _tile_index(
    rng, *, input_size: tuple[int, int], max_latitude: float, tile_size: int
) -> tuple[int, int]:
    """
    Draw a random (y, x) position for the top-left corner of a tile within an image

    :param rng: random number generator; its ``integers(low, high)`` method is called twice,
                first for the row and then for the column
    :param input_size: The (rows, columns) size of the input image
    :param max_latitude: The maximum latitude for the tiles;
                         will exclude tiles which extend above/below this latitude N/S.
    :param tile_size: The size of each tile.

    :returns: A tuple of (y, x) indices, as plain python ints
    """
    # Row bounds come from the latitude constraint
    row_low, row_high = _included_indices(input_size[0], tile_size, max_latitude)

    # Columns are unconstrained, other than the tile having to fit inside the image
    col_low = 0
    col_high = input_size[1] - tile_size + 1

    row = int(rng.integers(row_low, row_high))
    col = int(rng.integers(col_low, col_high))

    return row, col
110+
111+
112+
def _tile(input_img: np.ndarray, start: tuple[int, int], size: int) -> np.ndarray:
    """
    Cut a square window out of a 2d array

    :param input_img: the array to slice
    :param start: (row, column) index of the window's top-left corner
    :param size: side length of the window
    :returns: the slice ``input_img[row : row + size, column : column + size]``
    """
    top = start[0]
    left = start[1]

    rows = slice(top, top + size)
    cols = slice(left, left + size)

    return input_img[rows, cols]
117+
118+
119+
def _tile_rms(tile: np.ndarray) -> float:
    """
    Root-mean-square of all the values in a tile

    :param tile: the input tile
    :returns: square root of the mean of the squared elements
    """
    squared = tile**2
    mean_square = np.mean(squared)
    return np.sqrt(mean_square)
127+
128+
129+
def extract_tiles(
    rng: np.random.Generator,
    input_img: np.ndarray,
    *,
    num_tiles: int,
    max_rms: float,
    max_latitude: float = 64.0,
    tile_size: int = 32,
) -> np.ndarray:
    """
    Randomly extract tiles from an input image.

    Tiles are selected such that no part of any tile exceeds the provided maximum latitude;
    this assumes the input image is centred on the equator and ranges from -90 to 90 degrees latitude.

    Candidate tiles are drawn one at a time and kept only if their RMS is strictly below
    ``max_rms``. Drawing continues until ``num_tiles`` tiles have been accepted, so this
    does not terminate if no tile in the allowed region can pass the RMS check
    (a tile whose RMS is NaN never passes, since the comparison is False).

    :param rng: a seeded numpy random number generator
    :param input_img: The input image from which to extract tiles.
    :param num_tiles: The number of tiles to extract.
    :param max_rms: Maximum allowed RMS value in a tile (exclusive)
    :param max_latitude: The maximum latitude for the tiles;
                         will exclude tiles which extend above/below this latitude N/S.
    :param tile_size: The size of each tile.

    :returns: A numpy array of shape (num_tiles, tile_size, tile_size) containing the extracted
              tiles, with the same dtype as the input image.
    :raises IOError: if the input image is not 2d
    :raises IOError: if the input image is smaller than the tile size
    :raises IOError: if the maximum latitude is not positive, or the allowed latitude band
                     has too few rows to fit a tile
    """
    if input_img.ndim != 2:
        raise IOError(f"Input image must be 2d; got shape {input_img.shape}")
    if input_img.shape[0] < tile_size or input_img.shape[1] < tile_size:
        raise IOError(
            f"Tile size must be smaller than image size; got {input_img.shape} but {tile_size=}"
        )

    tiles = np.empty((num_tiles, tile_size, tile_size), dtype=input_img.dtype)
    indices_found = 0
    while indices_found < num_tiles:
        # The valid index ranges (including the latitude constraint) are handled by _tile_index
        y, x = _tile_index(
            rng,
            input_size=input_img.shape,
            max_latitude=max_latitude,
            tile_size=tile_size,
        )

        tile = _tile(input_img, (y, x), tile_size)

        # Slicing past the edge of the image would silently give a smaller tile; catch that here
        if tile.shape != (tile_size, tile_size):
            raise IOError(
                f"Extracted tile has wrong shape {tile.shape}, expected {(tile_size, tile_size)}"
            )

        # Only keep the tile if its RMS is below the threshold; otherwise draw again
        if _tile_rms(tile) < max_rms:
            tiles[indices_found] = tile
            indices_found += 1

    return tiles

src/notebooks/read_dat.ipynb

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": null,
5+
"execution_count": 1,
66
"id": "1612eebf",
77
"metadata": {},
88
"outputs": [],
@@ -79,6 +79,32 @@
7979
"id": "dd19669a",
8080
"metadata": {},
8181
"outputs": [],
82+
"source": [
83+
"from current_denoising.generation.ioutils import extract_tiles\n",
84+
"\n",
85+
"rng = np.random.default_rng(1234)\n",
86+
"\n",
87+
"tiles = extract_tiles(rng, data, num_tiles=16, max_rms=np.inf, max_latitude=10.0, tile_size=32)"
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": null,
93+
"id": "1131dd37",
94+
"metadata": {},
95+
"outputs": [],
96+
"source": [
97+
"fig, axes = plt.subplots(4, 4, figsize=(12, 12))\n",
98+
"for axis, tile in zip(axes.flat, tiles):\n",
99+
" im = axis.imshow(tile, origin=\"lower\", norm=\"log\", vmin=np.nanmin(data), vmax=np.nanmax(data))"
100+
]
101+
},
102+
{
103+
"cell_type": "code",
104+
"execution_count": null,
105+
"id": "def15562",
106+
"metadata": {},
107+
"outputs": [],
82108
"source": []
83109
}
84110
],

0 commit comments

Comments
 (0)