mirror of
https://github.com/qurator-spk/sbb_binarization.git
synced 2025-06-09 12:19:56 +02:00
Improving comments in the code
This commit is contained in:
parent
4112c6fe71
commit
1baa70e382
1 changed files with 4 additions and 2 deletions
|
@ -44,7 +44,7 @@ class SbbBinarizer:
|
|||
padded_image = np.zeros((padded_image_height, padded_image_width, image_channels))
|
||||
padded_image[0:original_image_height, 0:original_image_width, :] = img[:, :, :]
|
||||
|
||||
image_batch = np.expand_dims(padded_image, 0) # To create the batch information
|
||||
image_batch = np.expand_dims(padded_image, 0) # Create the batch dimension
|
||||
patches = tf.image.extract_patches(
|
||||
images=image_batch,
|
||||
sizes=[1, self.model_height, self.model_width, 1],
|
||||
|
@ -117,6 +117,7 @@ def split_list_into_worker_batches(files: List[Any], number_of_workers: int) ->
|
|||
def batch_predict(input_data):
|
||||
model_dir, input_images, output_images, worker_number = input_data
|
||||
print(f"Setting visible cuda devices to {str(worker_number)}")
|
||||
# Each worker thread will be assigned only one of the available GPUs to allow multiprocessing across GPUs
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(worker_number)
|
||||
|
||||
binarizer = SbbBinarizer()
|
||||
|
@ -146,13 +147,14 @@ if __name__ == '__main__':
|
|||
output_images = [output_path / (i.relative_to(input_path)) for i in input_images]
|
||||
input_images = [i for i in input_images]
|
||||
|
||||
print(f"Starting binarization of {len(input_images)} images")
|
||||
print(f"Starting batch-binarization of {len(input_images)} images")
|
||||
|
||||
number_of_gpus = len(tf.config.list_physical_devices('GPU'))
|
||||
number_of_workers = max(1, number_of_gpus)
|
||||
image_batches = split_list_into_worker_batches(input_images, number_of_workers)
|
||||
output_batches = split_list_into_worker_batches(output_images, number_of_workers)
|
||||
|
||||
# Must use spawn to create completely new process that has its own resources to properly multiprocess across GPUs
|
||||
with WorkerPool(n_jobs=number_of_workers, start_method='spawn') as pool:
|
||||
model_dirs = itertools.repeat(model_directory, len(image_batches))
|
||||
input_data = zip(model_dirs, image_batches, output_batches, range(number_of_workers))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue