Skip to content

Commit e8972f5

Browse files
committed
Improved loading of models to allow providing a directory, added a few type-hints and improved the code-style a little bit by running an auto-formatter on the entire file.
1 parent c99be55 commit e8972f5

File tree

1 file changed

+57
-59
lines changed

1 file changed

+57
-59
lines changed

sbb_binarize/sbb_binarize.py

Lines changed: 57 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,42 @@
11
"""
22
Tool to load model and binarize a given image.
33
"""
4-
4+
import argparse
55
import sys
6-
from glob import glob
76
from os import environ, devnull
8-
from os.path import join
9-
from warnings import catch_warnings, simplefilter
7+
from pathlib import Path
8+
from typing import Union
109

11-
import numpy as np
12-
from PIL import Image
1310
import cv2
11+
import numpy as np
12+
1413
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
1514
stderr = sys.stderr
1615
sys.stderr = open(devnull, 'w')
1716
import tensorflow as tf
1817
from tensorflow.keras.models import load_model
1918
from tensorflow.python.keras import backend as tensorflow_backend
20-
sys.stderr = stderr
2119

20+
sys.stderr = stderr
2221

2322
import logging
2423

24+
2525
def resize_image(img_in, input_height, input_width):
2626
return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
2727

28+
2829
class SbbBinarizer:
2930

30-
def __init__(self, model_dir, logger=None):
31-
self.model_dir = model_dir
31+
def __init__(self, model_dir: Union[str, Path], logger=None):
32+
model_dir = Path(model_dir)
3233
self.log = logger if logger else logging.getLogger('SbbBinarizer')
3334

3435
self.start_new_session()
3536

36-
self.model_files = glob('%s/*.h5' % self.model_dir)
37+
self.model_files = list([str(p.absolute()) for p in model_dir.rglob("*.h5")])
38+
if not self.model_files:
39+
raise ValueError(f"No models found in {str(model_dir)}")
3740

3841
self.models = []
3942
for model_file in self.model_files:
@@ -51,54 +54,51 @@ def end_session(self):
5154
self.session.close()
5255
del self.session
5356

54-
def load_model(self, model_name):
55-
model = load_model(join(self.model_dir, model_name), compile=False)
56-
model_height = model.layers[len(model.layers)-1].output_shape[1]
57-
model_width = model.layers[len(model.layers)-1].output_shape[2]
58-
n_classes = model.layers[len(model.layers)-1].output_shape[3]
57+
def load_model(self, model_path: str):
58+
model = load_model(model_path, compile=False)
59+
model_height = model.layers[len(model.layers) - 1].output_shape[1]
60+
model_width = model.layers[len(model.layers) - 1].output_shape[2]
61+
n_classes = model.layers[len(model.layers) - 1].output_shape[3]
5962
return model, model_height, model_width, n_classes
6063

6164
def predict(self, model_in, img, use_patches):
6265
tensorflow_backend.set_session(self.session)
6366
model, model_height, model_width, n_classes = model_in
64-
67+
6568
img_org_h = img.shape[0]
6669
img_org_w = img.shape[1]
67-
70+
6871
if img.shape[0] < model_height and img.shape[1] >= model_width:
69-
img_padded = np.zeros(( model_height, img.shape[1], img.shape[2] ))
70-
71-
index_start_h = int( abs( img.shape[0] - model_height) /2.)
72+
img_padded = np.zeros((model_height, img.shape[1], img.shape[2]))
73+
74+
index_start_h = int(abs(img.shape[0] - model_height) / 2.)
7275
index_start_w = 0
73-
74-
img_padded [ index_start_h: index_start_h+img.shape[0], :, : ] = img[:,:,:]
75-
76+
77+
img_padded[index_start_h: index_start_h + img.shape[0], :, :] = img[:, :, :]
78+
7679
elif img.shape[0] >= model_height and img.shape[1] < model_width:
77-
img_padded = np.zeros(( img.shape[0], model_width, img.shape[2] ))
78-
79-
index_start_h = 0
80-
index_start_w = int( abs( img.shape[1] - model_width) /2.)
81-
82-
img_padded [ :, index_start_w: index_start_w+img.shape[1], : ] = img[:,:,:]
83-
84-
80+
img_padded = np.zeros((img.shape[0], model_width, img.shape[2]))
81+
82+
index_start_h = 0
83+
index_start_w = int(abs(img.shape[1] - model_width) / 2.)
84+
85+
img_padded[:, index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]
86+
87+
8588
elif img.shape[0] < model_height and img.shape[1] < model_width:
86-
img_padded = np.zeros(( model_height, model_width, img.shape[2] ))
87-
88-
index_start_h = int( abs( img.shape[0] - model_height) /2.)
89-
index_start_w = int( abs( img.shape[1] - model_width) /2.)
90-
91-
img_padded [ index_start_h: index_start_h+img.shape[0], index_start_w: index_start_w+img.shape[1], : ] = img[:,:,:]
92-
89+
img_padded = np.zeros((model_height, model_width, img.shape[2]))
90+
91+
index_start_h = int(abs(img.shape[0] - model_height) / 2.)
92+
index_start_w = int(abs(img.shape[1] - model_width) / 2.)
93+
94+
img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]
95+
9396
else:
9497
index_start_h = 0
95-
index_start_w = 0
98+
index_start_w = 0
9699
img_padded = np.copy(img)
97-
98-
100+
99101
img = np.copy(img_padded)
100-
101-
102102

103103
if use_patches:
104104

@@ -107,7 +107,6 @@ def predict(self, model_in, img, use_patches):
107107
width_mid = model_width - 2 * margin
108108
height_mid = model_height - 2 * margin
109109

110-
111110
img = img / float(255.0)
112111

113112
img_h = img.shape[0]
@@ -167,49 +166,49 @@ def predict(self, model_in, img, use_patches):
167166
mask_true[index_y_d + 0:index_y_u - margin, index_x_d + 0:index_x_u - margin] = seg
168167
prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + 0:index_x_u - margin, :] = seg_color
169168

170-
elif i == nxf-1 and j == nyf-1:
169+
elif i == nxf - 1 and j == nyf - 1:
171170
seg_color = seg_color[margin:seg_color.shape[0] - 0, margin:seg_color.shape[1] - 0, :]
172171
seg = seg[margin:seg.shape[0] - 0, margin:seg.shape[1] - 0]
173172

174173
mask_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - 0] = seg
175174
prediction_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - 0, :] = seg_color
176175

177-
elif i == 0 and j == nyf-1:
176+
elif i == 0 and j == nyf - 1:
178177
seg_color = seg_color[margin:seg_color.shape[0] - 0, 0:seg_color.shape[1] - margin, :]
179178
seg = seg[margin:seg.shape[0] - 0, 0:seg.shape[1] - margin]
180179

181180
mask_true[index_y_d + margin:index_y_u - 0, index_x_d + 0:index_x_u - margin] = seg
182181
prediction_true[index_y_d + margin:index_y_u - 0, index_x_d + 0:index_x_u - margin, :] = seg_color
183182

184-
elif i == nxf-1 and j == 0:
183+
elif i == nxf - 1 and j == 0:
185184
seg_color = seg_color[0:seg_color.shape[0] - margin, margin:seg_color.shape[1] - 0, :]
186185
seg = seg[0:seg.shape[0] - margin, margin:seg.shape[1] - 0]
187186

188187
mask_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - 0] = seg
189188
prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - 0, :] = seg_color
190189

191-
elif i == 0 and j != 0 and j != nyf-1:
190+
elif i == 0 and j != 0 and j != nyf - 1:
192191
seg_color = seg_color[margin:seg_color.shape[0] - margin, 0:seg_color.shape[1] - margin, :]
193192
seg = seg[margin:seg.shape[0] - margin, 0:seg.shape[1] - margin]
194193

195194
mask_true[index_y_d + margin:index_y_u - margin, index_x_d + 0:index_x_u - margin] = seg
196195
prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + 0:index_x_u - margin, :] = seg_color
197196

198-
elif i == nxf-1 and j != 0 and j != nyf-1:
197+
elif i == nxf - 1 and j != 0 and j != nyf - 1:
199198
seg_color = seg_color[margin:seg_color.shape[0] - margin, margin:seg_color.shape[1] - 0, :]
200199
seg = seg[margin:seg.shape[0] - margin, margin:seg.shape[1] - 0]
201200

202201
mask_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - 0] = seg
203202
prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - 0, :] = seg_color
204203

205-
elif i != 0 and i != nxf-1 and j == 0:
204+
elif i != 0 and i != nxf - 1 and j == 0:
206205
seg_color = seg_color[0:seg_color.shape[0] - margin, margin:seg_color.shape[1] - margin, :]
207206
seg = seg[0:seg.shape[0] - margin, margin:seg.shape[1] - margin]
208207

209208
mask_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - margin] = seg
210209
prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - margin, :] = seg_color
211210

212-
elif i != 0 and i != nxf-1 and j == nyf-1:
211+
elif i != 0 and i != nxf - 1 and j == nyf - 1:
213212
seg_color = seg_color[margin:seg_color.shape[0] - 0, margin:seg_color.shape[1] - margin, :]
214213
seg = seg[margin:seg.shape[0] - 0, margin:seg.shape[1] - margin]
215214

@@ -222,10 +221,8 @@ def predict(self, model_in, img, use_patches):
222221

223222
mask_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin] = seg
224223
prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin, :] = seg_color
225-
226-
227-
228-
prediction_true = prediction_true[index_start_h: index_start_h+img_org_h, index_start_w: index_start_w+img_org_w,:]
224+
225+
prediction_true = prediction_true[index_start_h: index_start_h + img_org_h, index_start_w: index_start_w + img_org_w, :]
229226
prediction_true = prediction_true.astype(np.uint8)
230227

231228
else:
@@ -240,17 +237,16 @@ def predict(self, model_in, img, use_patches):
240237
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
241238
prediction_true = resize_image(seg_color, img_h_page, img_w_page)
242239
prediction_true = prediction_true.astype(np.uint8)
243-
return prediction_true[:,:,0]
240+
return prediction_true[:, :, 0]
244241

245242
def run(self, image=None, image_path=None, save=None, use_patches=False):
246-
if (image is not None and image_path is not None) or \
247-
(image is None and image_path is None):
243+
if (image is not None and image_path is not None) or (image is None and image_path is None):
248244
raise ValueError("Must pass either a opencv2 image or an image_path")
249245
if image_path is not None:
250246
image = cv2.imread(image_path)
251247
img_last = 0
252248
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
253-
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))
249+
self.log.info(f"Predicting with model {model_file} [{n + 1}/{len(self.model_files)}]")
254250

255251
res = self.predict(model, image, use_patches)
256252

@@ -270,5 +266,7 @@ def run(self, image=None, image_path=None, save=None, use_patches=False):
270266
img_last[:, :][img_last[:, :] > 0] = 255
271267
img_last = (img_last[:, :] == 0) * 255
272268
if save:
269+
# Create the output directory (and if necessary it's parents) if it doesn't exist already
270+
Path(save).parent.mkdir(parents=True, exist_ok=True)
273271
cv2.imwrite(save, img_last)
274272
return img_last

0 commit comments

Comments
 (0)