From 1baa70e38221f231ce7a4b18ff3a6f5291323df1 Mon Sep 17 00:00:00 2001
From: Alexander Pacha
Date: Tue, 30 Aug 2022 10:47:29 +0200
Subject: [PATCH] Improving comments in the code

---
 sbb_binarize/sbb_binarize.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/sbb_binarize/sbb_binarize.py b/sbb_binarize/sbb_binarize.py
index a7a5dd3..c49849c 100644
--- a/sbb_binarize/sbb_binarize.py
+++ b/sbb_binarize/sbb_binarize.py
@@ -44,7 +44,7 @@ class SbbBinarizer:
         padded_image = np.zeros((padded_image_height, padded_image_width, image_channels))
         padded_image[0:original_image_height, 0:original_image_width, :] = img[:, :, :]
 
-        image_batch = np.expand_dims(padded_image, 0)  # To create the batch information
+        image_batch = np.expand_dims(padded_image, 0)  # Create the batch dimension
         patches = tf.image.extract_patches(
             images=image_batch,
             sizes=[1, self.model_height, self.model_width, 1],
@@ -117,6 +117,7 @@ def split_list_into_worker_batches(files: List[Any], number_of_workers: int) ->
 def batch_predict(input_data):
     model_dir, input_images, output_images, worker_number = input_data
     print(f"Setting visible cuda devices to {str(worker_number)}")
+    # Each worker process is assigned exactly one of the available GPUs, so inference can run in parallel across GPUs
     os.environ["CUDA_VISIBLE_DEVICES"] = str(worker_number)
 
     binarizer = SbbBinarizer()
@@ -146,13 +147,14 @@ if __name__ == '__main__':
         output_images = [output_path / (i.relative_to(input_path)) for i in input_images]
         input_images = [i for i in input_images]
 
-        print(f"Starting binarization of {len(input_images)} images")
+        print(f"Starting batch-binarization of {len(input_images)} images")
 
         number_of_gpus = len(tf.config.list_physical_devices('GPU'))
         number_of_workers = max(1, number_of_gpus)
         image_batches = split_list_into_worker_batches(input_images, number_of_workers)
         output_batches = split_list_into_worker_batches(output_images, number_of_workers)
 
+        # 'spawn' must be used to create completely new processes, each with its own resources, so that multiprocessing across GPUs works properly
        with WorkerPool(n_jobs=number_of_workers, start_method='spawn') as pool:
             model_dirs = itertools.repeat(model_directory, len(image_batches))
             input_data = zip(model_dirs, image_batches, output_batches, range(number_of_workers))
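
For context, the first hunk touches the tiling step that cuts a padded page image into model-sized patches. Below is a minimal, self-contained sketch of that step; the concrete sizes, the random stand-in image, and the non-overlapping strides (the hunk ends before the remaining extract_patches arguments, so strides equal to the patch size is an assumption) are not taken from the patch:

    import numpy as np
    import tensorflow as tf

    model_height, model_width = 256, 256  # stand-ins for self.model_height / self.model_width
    img = np.random.rand(1000, 700, 3).astype(np.float32)  # hypothetical input image (H, W, C)

    # Pad the image up to the next multiple of the patch size so it tiles evenly.
    original_image_height, original_image_width, image_channels = img.shape
    padded_image_height = int(np.ceil(original_image_height / model_height)) * model_height
    padded_image_width = int(np.ceil(original_image_width / model_width)) * model_width
    padded_image = np.zeros((padded_image_height, padded_image_width, image_channels), dtype=np.float32)
    padded_image[0:original_image_height, 0:original_image_width, :] = img[:, :, :]

    image_batch = np.expand_dims(padded_image, 0)  # create the batch dimension
    patches = tf.image.extract_patches(
        images=image_batch,
        sizes=[1, model_height, model_width, 1],
        strides=[1, model_height, model_width, 1],
        rates=[1, 1, 1, 1],
        padding='VALID',
    )
    print(patches.shape)  # (1, 4, 3, 256 * 256 * 3): a grid of flattened, non-overlapping tiles

Padding up to a multiple of the patch size guarantees the tile grid covers the whole page, which is why the padded buffer is allocated first and the original pixels are copied into its top-left corner.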
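The two new comments in the later hunks describe the multi-GPU pattern: pin each spawned worker to one GPU via CUDA_VISIBLE_DEVICES, and start workers with 'spawn' so no CUDA state is inherited. Below is a minimal sketch of that pattern using mpire's WorkerPool as in the patch; the file paths, the stub batch_predict body, and the two-worker setup are hypothetical:

    import itertools
    import os

    from mpire import WorkerPool


    def batch_predict(model_dir, input_images, output_images, worker_number):
        # Pin this worker to exactly one GPU *before* TensorFlow initializes CUDA,
        # so each process only sees (and allocates memory on) its own device.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(worker_number)
        import tensorflow as tf  # deferred import: must run after the variable is set

        print(f"Worker {worker_number} sees: {tf.config.list_physical_devices('GPU')}")
        # ... load the model from model_dir and binarize input_images into output_images ...


    if __name__ == '__main__':  # required guard: 'spawn' re-imports this module in each child
        number_of_workers = 2  # e.g. one worker per available GPU
        image_batches = [["page_0.png"], ["page_1.png"]]           # hypothetical inputs
        output_batches = [["page_0.bin.png"], ["page_1.bin.png"]]  # hypothetical outputs
        model_dirs = itertools.repeat("/path/to/saved_model", number_of_workers)
        input_data = list(zip(model_dirs, image_batches, output_batches, range(number_of_workers)))

        # 'spawn' starts fresh interpreter processes, so no CUDA context or TensorFlow
        # state is inherited from the parent ('fork' would share that state across workers).
        with WorkerPool(n_jobs=number_of_workers, start_method='spawn') as pool:
            pool.map(batch_predict, input_data)  # mpire unpacks each tuple into the four parameters

Note that the sketch gives batch_predict four parameters rather than the patch's single input_data tuple, because mpire unpacks tuple work items into separate arguments by default; the deferred tensorflow import matters because CUDA_VISIBLE_DEVICES only takes effect if it is set before TensorFlow first touches the driver.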