@@ -35,3 +35,155 @@ def read_currents(path: pathlib.Path) -> np.ndarray:
3535 raise IOError ("Close marker does not match the opener." )
3636
3737 return data .reshape (shape )
38+
39+
40+ def _included_indices (
41+ n_rows : int , tile_size : int , max_latitude : float
42+ ) -> tuple [int , int ]:
43+ """
44+ Find the range of y-indices to select from, given the size of the input image
45+ and the maximum latitude.
46+
47+ Assumes the input image is centred on the equator and ranges from -90 to 90 degrees latitude.
48+ """
49+ if max_latitude <= 0 :
50+ raise IOError ("Maximum latitude must be > 0" )
51+
52+ # Given the number of rows in the image, find which latitude each row corresponds to
53+ latitudes = np .linspace (90 , - 90 , n_rows , endpoint = True )
54+
55+ # Check if we have enough allowed latitudes
56+ allowed_latitudes = np .sum ((latitudes < max_latitude ) & (latitudes > - max_latitude ))
57+ if allowed_latitudes < tile_size :
58+ raise IOError (
59+ f"Not enough allowed latitudes ({ allowed_latitudes } ) to fit a tile of size { tile_size } "
60+ )
61+
62+ # Find the first latitude that is <= to the provided maximum
63+ for min_row in range (n_rows ):
64+ if max_latitude >= latitudes [min_row ]:
65+ break
66+ else :
67+ raise RuntimeError ("No rows found below the provided maximum latitude" )
68+
69+ # Find the last latitude that is >= - the provided minimum
70+ for max_row in range (min_row + 1 , n_rows - tile_size + 2 ):
71+ index = max_row + tile_size - 1
72+ # If we have fallen off the bottom of the image, take the last row
73+ if index == n_rows :
74+ break
75+
76+ # If the bottom of the tile lies on exactly the threshold, take this row
77+ if latitudes [index ] == - max_latitude :
78+ break
79+
80+ # If the bottom of the tile is less than the threshold, take the previous row
81+ if latitudes [index ] < - max_latitude :
82+ max_row -= 1
83+ break
84+ # Don't raise here if we don't find a row above the max latitude - that's fine, we'll just fall
85+ # through and take the last row
86+
87+ return min_row , max_row
88+
89+
90+ def _tile_index (
91+ rng , * , input_size : tuple [int , int ], max_latitude : float , tile_size : int
92+ ) -> tuple [int , int ]:
93+ """
94+ Generate random (y, x) indices for the top-left corner of a tile within an image
95+
96+ :param input_img_size: The size of the input image
97+ :param tile_size: The size of each tile.
98+ :param max_latitude: The maximum latitude for the tiles;
99+ will exclude tiles which extend above/below this latitude N/S.
100+
101+ :returns: A tuple of (y, x) indices
102+ """
103+ height_range = _included_indices (input_size [0 ], tile_size , max_latitude )
104+ width_range = (0 , input_size [1 ] - tile_size + 1 )
105+
106+ y_index = int (rng .integers (* height_range ))
107+ x_index = int (rng .integers (* width_range ))
108+
109+ return y_index , x_index
110+
111+
112+ def _tile (input_img : np .ndarray , start : tuple [int , int ], size : int ) -> np .ndarray :
113+ """
114+ Extract a tile at the provided location + size from a 2d array
115+ """
116+ return input_img [start [0 ] : start [0 ] + size , start [1 ] : start [1 ] + size ]
117+
118+
119+ def _tile_rms (tile : np .ndarray ) -> float :
120+ """
121+ Calculate the RMS of a tile
122+
123+ :param tile: the input tile
124+ :returns: the RMS of the tile
125+ """
126+ return np .sqrt (np .mean (tile ** 2 ))
127+
128+
129+ def extract_tiles (
130+ rng : np .random .Generator ,
131+ input_img : np .ndarray ,
132+ * ,
133+ num_tiles : int ,
134+ max_rms : float ,
135+ max_latitude : float = 64.0 ,
136+ tile_size : int = 32 ,
137+ ) -> np .ndarray :
138+ """
139+ Randomly extract tiles from an input image.
140+
141+ Tiles are selected such that no part of any tile exceeds the provided maximum latitude;
142+ this assumes the input image is centred on the equator and ranges from -90 to 90 degrees latitude.
143+
144+ :param rng: a seeded numpy random number generator
145+ :param input_img: The input image from which to extract tiles.
146+ :param tile_size: The size of each tile.
147+ :param num_tiles: The number of tiles to extract.
148+ :param max_rms: Maximum allowed RMS value in a tile
149+ :param max_latitude: The maximum latitude for the tiles;
150+ will exclude tiles which extend above/below this latitude N/S.
151+
152+ :returns: A numpy array containing the extracted tiles.
153+ :raises IOError: if the input image is not 2d
154+ :raises IOError: if the input image is smaller than the tile size
155+ """
156+ if input_img .ndim != 2 :
157+ raise IOError (f"Input image must be 2d; got shape { input_img .shape } " )
158+ if input_img .shape [0 ] < tile_size or input_img .shape [1 ] < tile_size :
159+ raise IOError (
160+ f"Tile size must be smaller than image size; got { input_img .shape } but { tile_size = } "
161+ )
162+
163+ # Choose the range of indices to pick from
164+ height_range = slice (0 , input_img .shape [0 ] - tile_size + 1 )
165+ width_range = slice (0 , input_img .shape [1 ] - tile_size + 1 )
166+
167+ tiles = np .empty ((num_tiles , tile_size , tile_size ), dtype = input_img .dtype )
168+ indices_found = 0
169+ while indices_found < num_tiles :
170+ y , x = _tile_index (
171+ rng ,
172+ input_size = input_img .shape ,
173+ max_latitude = max_latitude ,
174+ tile_size = tile_size ,
175+ )
176+
177+ tile = _tile (input_img , (y , x ), tile_size )
178+
179+ if tile .shape != (tile_size , tile_size ):
180+ raise IOError (
181+ f"Extracted tile has wrong shape { tile .shape } , expected { (tile_size , tile_size )} "
182+ )
183+
184+ # Check the RMS of the tile
185+ if _tile_rms (tile ) < max_rms :
186+ tiles [indices_found ] = tile
187+ indices_found += 1
188+
189+ return tiles
0 commit comments