@@ -44,7 +44,7 @@ def binarize_image(self, image_path: Path, save_path: Path):
4444 padded_image = np .zeros ((padded_image_height , padded_image_width , image_channels ))
4545 padded_image [0 :original_image_height , 0 :original_image_width , :] = img [:, :, :]
4646
47- image_batch = np .expand_dims (padded_image , 0 ) # To create the batch information
47+ image_batch = np .expand_dims (padded_image , 0 ) # Create the batch dimension
4848 patches = tf .image .extract_patches (
4949 images = image_batch ,
5050 sizes = [1 , self .model_height , self .model_width , 1 ],
@@ -117,6 +117,7 @@ def split_list_into_worker_batches(files: List[Any], number_of_workers: int) ->
117117def batch_predict (input_data ):
118118 model_dir , input_images , output_images , worker_number = input_data
119119 print (f"Setting visible cuda devices to { str (worker_number )} " )
120+ # Each worker process is assigned exactly one of the available GPUs so that the work is spread across GPUs
120121 os .environ ["CUDA_VISIBLE_DEVICES" ] = str (worker_number )
121122
122123 binarizer = SbbBinarizer ()
@@ -146,13 +147,14 @@ def batch_predict(input_data):
146147 output_images = [output_path / (i .relative_to (input_path )) for i in input_images ]
147148 input_images = [i for i in input_images ]
148149
149- print (f"Starting binarization of { len (input_images )} images" )
150+ print (f"Starting batch- binarization of { len (input_images )} images" )
150151
151152 number_of_gpus = len (tf .config .list_physical_devices ('GPU' ))
152153 number_of_workers = max (1 , number_of_gpus )
153154 image_batches = split_list_into_worker_batches (input_images , number_of_workers )
154155 output_batches = split_list_into_worker_batches (output_images , number_of_workers )
155156
157+ # Must use 'spawn' so that each worker is a completely new process with its own resources; this is needed to properly multiprocess across GPUs
156158 with WorkerPool (n_jobs = number_of_workers , start_method = 'spawn' ) as pool :
157159 model_dirs = itertools .repeat (model_directory , len (image_batches ))
158160 input_data = zip (model_dirs , image_batches , output_batches , range (number_of_workers ))
0 commit comments