@@ -44,7 +44,7 @@ def binarize_image(self, image_path: Path, save_path: Path):
44
44
padded_image = np .zeros ((padded_image_height , padded_image_width , image_channels ))
45
45
padded_image [0 :original_image_height , 0 :original_image_width , :] = img [:, :, :]
46
46
47
- image_batch = np .expand_dims (padded_image , 0 ) # To create the batch information
47
+ image_batch = np .expand_dims (padded_image , 0 ) # Create the batch dimension
48
48
patches = tf .image .extract_patches (
49
49
images = image_batch ,
50
50
sizes = [1 , self .model_height , self .model_width , 1 ],
@@ -117,6 +117,7 @@ def split_list_into_worker_batches(files: List[Any], number_of_workers: int) ->
117
117
def batch_predict (input_data ):
118
118
model_dir , input_images , output_images , worker_number = input_data
119
119
print (f"Setting visible cuda devices to { str (worker_number )} " )
120
+ # Each worker thread will be assigned only one of the available GPUs to allow multiprocessing across GPUs
120
121
os .environ ["CUDA_VISIBLE_DEVICES" ] = str (worker_number )
121
122
122
123
binarizer = SbbBinarizer ()
@@ -146,13 +147,14 @@ def batch_predict(input_data):
146
147
output_images = [output_path / (i .relative_to (input_path )) for i in input_images ]
147
148
input_images = [i for i in input_images ]
148
149
149
- print (f"Starting binarization of { len (input_images )} images" )
150
+ print (f"Starting batch- binarization of { len (input_images )} images" )
150
151
151
152
number_of_gpus = len (tf .config .list_physical_devices ('GPU' ))
152
153
number_of_workers = max (1 , number_of_gpus )
153
154
image_batches = split_list_into_worker_batches (input_images , number_of_workers )
154
155
output_batches = split_list_into_worker_batches (output_images , number_of_workers )
155
156
157
+ # Must use spawn to create completely new process that has its own resources to properly multiprocess across GPUs
156
158
with WorkerPool (n_jobs = number_of_workers , start_method = 'spawn' ) as pool :
157
159
model_dirs = itertools .repeat (model_directory , len (image_batches ))
158
160
input_data = zip (model_dirs , image_batches , output_batches , range (number_of_workers ))
0 commit comments